summaryrefslogtreecommitdiff
path: root/candle-core/src
diff options
context:
space:
mode:
Diffstat (limited to 'candle-core/src')
-rw-r--r--candle-core/src/backprop.rs7
1 files changed, 7 insertions, 0 deletions
diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs
index 7488d939..155f49c5 100644
--- a/candle-core/src/backprop.rs
+++ b/candle-core/src/backprop.rs
@@ -238,6 +238,13 @@ impl Tensor {
.conv2d(&grad.transpose(0, 1)?, *padding, *dilation, *stride, 1)?
.transpose(0, 1)?;
let sum_grad = grads.or_insert(kernel)?;
+ let (_, _, k0, k1) = kernel.dims4()?;
+ let (_, _, g_k0, g_k1) = grad_kernel.dims4()?;
+ let grad_kernel = if g_k0 != k0 || g_k1 != k1 {
+ grad_kernel.narrow(2, 0, k0)?.narrow(3, 0, k1)?
+ } else {
+ grad_kernel
+ };
*sum_grad = sum_grad.add(&grad_kernel)?;
}
Op::ConvTranspose2D { .. } => Err(Error::BackwardNotSupported {