summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-core/src/cpu_backend.rs2
-rw-r--r--candle-core/tests/pool_tests.rs17
2 files changed, 19 insertions, 0 deletions
diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index 0ec19559..10c6cc4a 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -660,6 +660,8 @@ impl Map1 for AvgPool2D {
let mut sum = T::zero();
for m in 0..k_h {
for n in 0..k_w {
+ let m = k_h * h_idx + m;
+ let n = k_w * w_idx + n;
sum += src[src_index + m * stride_h + n * stride_w]
}
}
diff --git a/candle-core/tests/pool_tests.rs b/candle-core/tests/pool_tests.rs
new file mode 100644
index 00000000..574182ca
--- /dev/null
+++ b/candle-core/tests/pool_tests.rs
@@ -0,0 +1,17 @@
+mod test_utils;
+use candle_core::{Device, Tensor};
+
+// https://github.com/huggingface/candle/issues/364
+#[test]
+fn avg_pool2d() -> anyhow::Result<()> {
+ let device = Device::Cpu;
+
+ let data: Vec<f32> = vec![
+ 1., 1., 1., 1., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
+ ];
+ let t = Tensor::from_vec(data, (1, 1, 4, 4), &device)?;
+
+ let pool = t.avg_pool2d((2, 2), (2, 2))?.squeeze(0)?.squeeze(0)?;
+ assert_eq!(pool.to_vec2::<f32>()?, [[0.5f32, 1.], [1., 1.]]);
+ Ok(())
+}