summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-core/src/backprop.rs7
-rw-r--r--candle-core/tests/conv_tests.rs65
2 files changed, 72 insertions, 0 deletions
diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs
index 7488d939..155f49c5 100644
--- a/candle-core/src/backprop.rs
+++ b/candle-core/src/backprop.rs
@@ -238,6 +238,13 @@ impl Tensor {
.conv2d(&grad.transpose(0, 1)?, *padding, *dilation, *stride, 1)?
.transpose(0, 1)?;
let sum_grad = grads.or_insert(kernel)?;
+ let (_, _, k0, k1) = kernel.dims4()?;
+ let (_, _, g_k0, g_k1) = grad_kernel.dims4()?;
+ let grad_kernel = if g_k0 != k0 || g_k1 != k1 {
+ grad_kernel.narrow(2, 0, k0)?.narrow(3, 0, k1)?
+ } else {
+ grad_kernel
+ };
*sum_grad = sum_grad.add(&grad_kernel)?;
}
Op::ConvTranspose2D { .. } => Err(Error::BackwardNotSupported {
diff --git a/candle-core/tests/conv_tests.rs b/candle-core/tests/conv_tests.rs
index 937ddf67..e7fdf138 100644
--- a/candle-core/tests/conv_tests.rs
+++ b/candle-core/tests/conv_tests.rs
@@ -479,6 +479,71 @@ fn conv2d_grad(dev: &Device) -> Result<()> {
]
]
);
+
+ // Replicate the issue from https://github.com/huggingface/candle/issues/1212
+ let res = t.i((.., .., 0..4, 0..4))?.conv2d(&w, 0, 2, 1, 1)?;
+ let loss = res.sqr()?.sum_all()?;
+ assert_eq!(test_utils::to_vec0_round(&loss, 2)?, 21.12f32);
+ let grads = loss.backward()?;
+ let grad_t = grads.get(&t).unwrap();
+ let grad_w = grads.get(&w).unwrap();
+ assert_eq!(grad_t.dims(), [1, 4, 5, 5]);
+ assert_eq!(grad_w.dims(), [2, 4, 3, 3]);
+ assert_eq!(
+ test_utils::to_vec3_round(&grad_t.i(0)?, 2)?,
+ [
+ [
+ [9.29, -7.03, 7.87, 0.0, 0.0],
+ [-1.8, -7.82, 5.9, 0.0, 0.0],
+ [-3.12, 4.49, 5.52, 0.0, 0.0],
+ [0.0, 0.0, 0.0, 0.0, 0.0],
+ [0.0, 0.0, 0.0, 0.0, 0.0]
+ ],
+ [
+ [21.73, 3.39, 4.77, 0.0, 0.0],
+ [8.25, 3.73, 27.61, 0.0, 0.0],
+ [-20.55, -5.61, -2.77, 0.0, 0.0],
+ [0.0, 0.0, 0.0, 0.0, 0.0],
+ [0.0, 0.0, 0.0, 0.0, 0.0]
+ ],
+ [
+ [-8.98, 9.91, -7.15, 0.0, 0.0],
+ [4.93, -0.33, 4.56, 0.0, 0.0],
+ [-6.7, -5.76, -8.05, 0.0, 0.0],
+ [0.0, 0.0, 0.0, 0.0, 0.0],
+ [0.0, 0.0, 0.0, 0.0, 0.0]
+ ],
+ [
+ [23.54, 6.98, -10.0, 0.0, 0.0],
+ [9.65, 6.18, 18.72, 0.0, 0.0],
+ [3.29, -5.27, 0.79, 0.0, 0.0],
+ [0.0, 0.0, 0.0, 0.0, 0.0],
+ [0.0, 0.0, 0.0, 0.0, 0.0]
+ ]
+ ]
+ );
+ assert_eq!(
+ test_utils::to_vec3_round(&grad_w.i(0)?, 2)?,
+ [
+ [
+ [-3.47, 7.44, 0.66],
+ [12.89, -3.4, -9.29],
+ [-14.16, -0.83, 7.14]
+ ],
+ [
+ [-3.23, 5.37, -3.02],
+ [-2.12, -11.24, 1.94],
+ [6.97, 7.2, 2.99]
+ ],
+ [
+ [-4.04, -3.31, 4.87],
+ [-6.68, -5.68, 1.73],
+ [-5.54, 4.32, 0.52]
+ ],
+ [[-4.72, 1.5, 4.72], [3.79, 4.04, 6.76], [-4.6, 5.8, 6.93]]
+ ]
+ );
+
Ok(())
}