summaryrefslogtreecommitdiff
path: root/candle-core/src/backprop.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-core/src/backprop.rs')
-rw-r--r--candle-core/src/backprop.rs15
1 files changed, 11 insertions, 4 deletions
diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs
index 9ecdee4f..f4f90373 100644
--- a/candle-core/src/backprop.rs
+++ b/candle-core/src/backprop.rs
@@ -197,21 +197,28 @@ impl Tensor {
kernel,
padding,
stride,
+ dilation,
} => {
// The output height for conv_transpose2d is:
// (i_h - 1) * stride - 2 * padding + dilation * (k_h - 1) + out_padding + 1
let grad_h = grad.dim(2)?;
let k_h = kernel.dim(2)?;
- let out_size = (grad_h - 1) * stride + (k_h - 1) + 1 - 2 * padding;
+ let out_size =
+ (grad_h - 1) * stride + dilation * (k_h - 1) + 1 - 2 * padding;
let out_padding = arg.dim(2)? - out_size;
- let grad_arg =
- grad.conv_transpose2d(kernel, *padding, out_padding, *stride)?;
+ let grad_arg = grad.conv_transpose2d(
+ kernel,
+ *padding,
+ out_padding,
+ *stride,
+ *dilation,
+ )?;
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&grad_arg)?;
let grad_kernel = arg
.transpose(0, 1)?
- .conv2d(&grad.transpose(0, 1)?, *padding, *stride, 1)?
+ .conv2d(&grad.transpose(0, 1)?, *padding, *dilation, *stride, 1)?
.transpose(0, 1)?;
let sum_grad = grads.or_insert(kernel)?;
*sum_grad = sum_grad.add(&grad_kernel)?;