diff options
Diffstat (limited to 'candle-core/src')
-rw-r--r-- | candle-core/src/backprop.rs | 7 |
1 files changed, 7 insertions, 0 deletions
diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs index 7488d939..155f49c5 100644 --- a/candle-core/src/backprop.rs +++ b/candle-core/src/backprop.rs @@ -238,6 +238,13 @@ impl Tensor { .conv2d(&grad.transpose(0, 1)?, *padding, *dilation, *stride, 1)? .transpose(0, 1)?; let sum_grad = grads.or_insert(kernel)?; + let (_, _, k0, k1) = kernel.dims4()?; + let (_, _, g_k0, g_k1) = grad_kernel.dims4()?; + let grad_kernel = if g_k0 != k0 || g_k1 != k1 { + grad_kernel.narrow(2, 0, k0)?.narrow(3, 0, k1)? + } else { + grad_kernel + }; *sum_grad = sum_grad.add(&grad_kernel)?; } Op::ConvTranspose2D { .. } => Err(Error::BackwardNotSupported { |