diff options
Diffstat (limited to 'candle-core/tests/pool_tests.rs')
-rw-r--r-- | candle-core/tests/pool_tests.rs | 10 |
1 files changed, 7 insertions, 3 deletions
diff --git a/candle-core/tests/pool_tests.rs b/candle-core/tests/pool_tests.rs index 009564fa..d2eb8f3f 100644 --- a/candle-core/tests/pool_tests.rs +++ b/candle-core/tests/pool_tests.rs @@ -56,9 +56,8 @@ fn avg_pool2d_pytorch(dev: &Device) -> Result<()> { Ok(()) } -#[test] -fn upsample_nearest2d() -> Result<()> { - let t = Tensor::arange(0f32, 6f32, &Device::Cpu)?.reshape((1, 1, 2, 3))?; +fn upsample_nearest2d(dev: &Device) -> Result<()> { + let t = Tensor::arange(0f32, 6f32, dev)?.reshape((1, 1, 2, 3))?; let upsampled = t.upsample_nearest2d(4, 6)?.i(0)?.i(0)?; assert_eq!( t.i(0)?.i(0)?.to_vec2::<f32>()?, @@ -83,3 +82,8 @@ test_device!( avg_pool2d_pytorch_gpu ); test_device!(max_pool2d, max_pool2d_cpu, max_pool2d_gpu); +test_device!( + upsample_nearest2d, + upsample_nearest2d_cpu, + upsample_nearest2d_gpu +); |