summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-08-09 16:44:16 +0200
committerGitHub <noreply@github.com>2023-08-09 15:44:16 +0100
commitb80348d22f8f0dadb6cc4101bde031d5de69a9a5 (patch)
treec2a64e014f628425fc1d1627d2be103441f9309e
parent3a62aee91ffd4b73eb7811cf08094ab1910a5256 (diff)
downloadcandle-b80348d22f8f0dadb6cc4101bde031d5de69a9a5.tar.gz
candle-b80348d22f8f0dadb6cc4101bde031d5de69a9a5.tar.bz2
candle-b80348d22f8f0dadb6cc4101bde031d5de69a9a5.zip
Bugfix for avg-pool + add some test. (#365)
-rw-r--r--candle-core/src/cpu_backend.rs2
-rw-r--r--candle-core/tests/pool_tests.rs17
2 files changed, 19 insertions, 0 deletions
diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index 0ec19559..10c6cc4a 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -660,6 +660,8 @@ impl Map1 for AvgPool2D {
let mut sum = T::zero();
for m in 0..k_h {
for n in 0..k_w {
+ let m = k_h * h_idx + m;
+ let n = k_w * w_idx + n;
sum += src[src_index + m * stride_h + n * stride_w]
}
}
diff --git a/candle-core/tests/pool_tests.rs b/candle-core/tests/pool_tests.rs
new file mode 100644
index 00000000..574182ca
--- /dev/null
+++ b/candle-core/tests/pool_tests.rs
@@ -0,0 +1,17 @@
+mod test_utils;
+use candle_core::{Device, Tensor};
+
+// https://github.com/huggingface/candle/issues/364
+#[test]
+fn avg_pool2d() -> anyhow::Result<()> {
+ let device = Device::Cpu;
+
+ let data: Vec<f32> = vec![
+ 1., 1., 1., 1., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
+ ];
+ let t = Tensor::from_vec(data, (1, 1, 4, 4), &device)?;
+
+ let pool = t.avg_pool2d((2, 2), (2, 2))?.squeeze(0)?.squeeze(0)?;
+ assert_eq!(pool.to_vec2::<f32>()?, [[0.5f32, 1.], [1., 1.]]);
+ Ok(())
+}