diff options
Diffstat (limited to 'candle-core/src/backprop.rs')
-rw-r--r-- | candle-core/src/backprop.rs | 15 |
1 files changed, 11 insertions, 4 deletions
diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs index 9ecdee4f..f4f90373 100644 --- a/candle-core/src/backprop.rs +++ b/candle-core/src/backprop.rs @@ -197,21 +197,28 @@ impl Tensor { kernel, padding, stride, + dilation, } => { // The output height for conv_transpose2d is: // (i_h - 1) * stride - 2 * padding + dilation * (k_h - 1) + out_padding + 1 let grad_h = grad.dim(2)?; let k_h = kernel.dim(2)?; - let out_size = (grad_h - 1) * stride + (k_h - 1) + 1 - 2 * padding; + let out_size = + (grad_h - 1) * stride + dilation * (k_h - 1) + 1 - 2 * padding; let out_padding = arg.dim(2)? - out_size; - let grad_arg = - grad.conv_transpose2d(kernel, *padding, out_padding, *stride)?; + let grad_arg = grad.conv_transpose2d( + kernel, + *padding, + out_padding, + *stride, + *dilation, + )?; let sum_grad = grads.or_insert(arg)?; *sum_grad = sum_grad.add(&grad_arg)?; let grad_kernel = arg .transpose(0, 1)? - .conv2d(&grad.transpose(0, 1)?, *padding, *stride, 1)? + .conv2d(&grad.transpose(0, 1)?, *padding, *dilation, *stride, 1)? .transpose(0, 1)?; let sum_grad = grads.or_insert(kernel)?; *sum_grad = sum_grad.add(&grad_kernel)?; |