summaryrefslogtreecommitdiff
path: root/candle-core/tests/pool_tests.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-core/tests/pool_tests.rs')
-rw-r--r--candle-core/tests/pool_tests.rs10
1 files changed, 7 insertions, 3 deletions
diff --git a/candle-core/tests/pool_tests.rs b/candle-core/tests/pool_tests.rs
index 009564fa..d2eb8f3f 100644
--- a/candle-core/tests/pool_tests.rs
+++ b/candle-core/tests/pool_tests.rs
@@ -56,9 +56,8 @@ fn avg_pool2d_pytorch(dev: &Device) -> Result<()> {
Ok(())
}
-#[test]
-fn upsample_nearest2d() -> Result<()> {
- let t = Tensor::arange(0f32, 6f32, &Device::Cpu)?.reshape((1, 1, 2, 3))?;
+fn upsample_nearest2d(dev: &Device) -> Result<()> {
+ let t = Tensor::arange(0f32, 6f32, dev)?.reshape((1, 1, 2, 3))?;
let upsampled = t.upsample_nearest2d(4, 6)?.i(0)?.i(0)?;
assert_eq!(
t.i(0)?.i(0)?.to_vec2::<f32>()?,
@@ -83,3 +82,8 @@ test_device!(
avg_pool2d_pytorch_gpu
);
test_device!(max_pool2d, max_pool2d_cpu, max_pool2d_gpu);
+test_device!(
+ upsample_nearest2d,
+ upsample_nearest2d_cpu,
+ upsample_nearest2d_gpu
+);