diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-08-09 16:44:16 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-08-09 15:44:16 +0100 |
commit | b80348d22f8f0dadb6cc4101bde031d5de69a9a5 (patch) | |
tree | c2a64e014f628425fc1d1627d2be103441f9309e | |
parent | 3a62aee91ffd4b73eb7811cf08094ab1910a5256 (diff) | |
download | candle-b80348d22f8f0dadb6cc4101bde031d5de69a9a5.tar.gz candle-b80348d22f8f0dadb6cc4101bde031d5de69a9a5.tar.bz2 candle-b80348d22f8f0dadb6cc4101bde031d5de69a9a5.zip |
Bugfix for avg-pool + add some test. (#365)
-rw-r--r-- | candle-core/src/cpu_backend.rs | 2 | ||||
-rw-r--r-- | candle-core/tests/pool_tests.rs | 17 |
2 files changed, 19 insertions, 0 deletions
diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs index 0ec19559..10c6cc4a 100644 --- a/candle-core/src/cpu_backend.rs +++ b/candle-core/src/cpu_backend.rs @@ -660,6 +660,8 @@ impl Map1 for AvgPool2D { let mut sum = T::zero(); for m in 0..k_h { for n in 0..k_w { + let m = k_h * h_idx + m; + let n = k_w * w_idx + n; sum += src[src_index + m * stride_h + n * stride_w] } } diff --git a/candle-core/tests/pool_tests.rs b/candle-core/tests/pool_tests.rs new file mode 100644 index 00000000..574182ca --- /dev/null +++ b/candle-core/tests/pool_tests.rs @@ -0,0 +1,17 @@ +mod test_utils; +use candle_core::{Device, Tensor}; + +// https://github.com/huggingface/candle/issues/364 +#[test] +fn avg_pool2d() -> anyhow::Result<()> { + let device = Device::Cpu; + + let data: Vec<f32> = vec![ + 1., 1., 1., 1., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., + ]; + let t = Tensor::from_vec(data, (1, 1, 4, 4), &device)?; + + let pool = t.avg_pool2d((2, 2), (2, 2))?.squeeze(0)?.squeeze(0)?; + assert_eq!(pool.to_vec2::<f32>()?, [[0.5f32, 1.], [1., 1.]]); + Ok(()) +} |