diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-10-29 10:50:04 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-10-29 09:50:04 +0000 |
commit | 46d6566c99f63fc74f3fbf5754183a49219224d5 (patch) | |
tree | 50b45cc125bb9b7cd6c5b03a31e966607c62e848 /candle-core/src | |
parent | 55bc3382cfd3a86018c54f2343567f7c0c0b677c (diff) | |
download | candle-46d6566c99f63fc74f3fbf5754183a49219224d5.tar.gz candle-46d6566c99f63fc74f3fbf5754183a49219224d5.tar.bz2 candle-46d6566c99f63fc74f3fbf5754183a49219224d5.zip |
Fix the conv2d gradient computation. (#1214)
Diffstat (limited to 'candle-core/src')
-rw-r--r-- | candle-core/src/backprop.rs | 7 |
1 files changed, 7 insertions, 0 deletions
diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs index 7488d939..155f49c5 100644 --- a/candle-core/src/backprop.rs +++ b/candle-core/src/backprop.rs @@ -238,6 +238,13 @@ impl Tensor { .conv2d(&grad.transpose(0, 1)?, *padding, *dilation, *stride, 1)? .transpose(0, 1)?; let sum_grad = grads.or_insert(kernel)?; + let (_, _, k0, k1) = kernel.dims4()?; + let (_, _, g_k0, g_k1) = grad_kernel.dims4()?; + let grad_kernel = if g_k0 != k0 || g_k1 != k1 { + grad_kernel.narrow(2, 0, k0)?.narrow(3, 0, k1)? + } else { + grad_kernel + }; *sum_grad = sum_grad.add(&grad_kernel)?; } Op::ConvTranspose2D { .. } => Err(Error::BackwardNotSupported { |