summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-core/src/cudnn.rs6
1 files changed, 3 insertions, 3 deletions
diff --git a/candle-core/src/cudnn.rs b/candle-core/src/cudnn.rs
index 235ad6e3..dd466ba2 100644
--- a/candle-core/src/cudnn.rs
+++ b/candle-core/src/cudnn.rs
@@ -54,8 +54,8 @@ pub(crate) fn launch_conv2d<
let x_shape = [
params.b_size as i32,
params.c_in as i32,
- params.i_w as i32,
params.i_h as i32,
+ params.i_w as i32,
];
// Note that `src` already starts at the proper offset.
let x = if src_l.is_contiguous() {
@@ -75,14 +75,14 @@ pub(crate) fn launch_conv2d<
[
params.c_out as i32,
params.c_in as i32,
- params.k_w as i32,
params.k_h as i32,
+ params.k_w as i32,
],
)?;
let (w_out, h_out) = (params.out_w() as i32, params.out_h() as i32);
let y = cudnn.create_4d_tensor(
cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
- [params.b_size as i32, params.c_out as i32, w_out, h_out],
+ [params.b_size as i32, params.c_out as i32, h_out, w_out],
)?;
let conv2d = Conv2dForward {
conv: &conv,