summaryrefslogtreecommitdiff
path: root/candle-core/src
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-10-29 10:50:04 +0100
committerGitHub <noreply@github.com>2023-10-29 09:50:04 +0000
commit46d6566c99f63fc74f3fbf5754183a49219224d5 (patch)
tree50b45cc125bb9b7cd6c5b03a31e966607c62e848 /candle-core/src
parent55bc3382cfd3a86018c54f2343567f7c0c0b677c (diff)
downloadcandle-46d6566c99f63fc74f3fbf5754183a49219224d5.tar.gz
candle-46d6566c99f63fc74f3fbf5754183a49219224d5.tar.bz2
candle-46d6566c99f63fc74f3fbf5754183a49219224d5.zip
Fix the conv2d gradient computation. (#1214)
Diffstat (limited to 'candle-core/src')
-rw-r--r--candle-core/src/backprop.rs7
1 files changed, 7 insertions, 0 deletions
diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs
index 7488d939..155f49c5 100644
--- a/candle-core/src/backprop.rs
+++ b/candle-core/src/backprop.rs
@@ -238,6 +238,13 @@ impl Tensor {
.conv2d(&grad.transpose(0, 1)?, *padding, *dilation, *stride, 1)?
.transpose(0, 1)?;
let sum_grad = grads.or_insert(kernel)?;
+ let (_, _, k0, k1) = kernel.dims4()?;
+ let (_, _, g_k0, g_k1) = grad_kernel.dims4()?;
+ let grad_kernel = if g_k0 != k0 || g_k1 != k1 {
+ grad_kernel.narrow(2, 0, k0)?.narrow(3, 0, k1)?
+ } else {
+ grad_kernel
+ };
*sum_grad = sum_grad.add(&grad_kernel)?;
}
Op::ConvTranspose2D { .. } => Err(Error::BackwardNotSupported {