diff options
Diffstat (limited to 'candle-core/tests/conv_tests.rs')
-rw-r--r-- | candle-core/tests/conv_tests.rs | 14 |
1 file changed, 7 insertions(+), 7 deletions(-)
diff --git a/candle-core/tests/conv_tests.rs b/candle-core/tests/conv_tests.rs index c777fec7..d09fa344 100644 --- a/candle-core/tests/conv_tests.rs +++ b/candle-core/tests/conv_tests.rs @@ -33,13 +33,13 @@ fn conv1d(dev: &Device) -> Result<()> { dev, )? .reshape((2, 4, 3))?; - let res = t.conv1d(&w, 0, 1)?; + let res = t.conv1d(&w, 0, 1, 1)?; assert_eq!(res.dims(), [1, 2, 3]); assert_eq!( test_utils::to_vec1_round(&res.flatten_all()?, 4)?, [2.6357, -1.3336, 4.1393, -1.1784, 3.5675, 0.5069] ); - let res = t.conv1d(&w, /*padding*/ 1, 1)?; + let res = t.conv1d(&w, /*padding*/ 1, 1, 1)?; assert_eq!(res.dims(), [1, 2, 5]); // Same as pytorch default padding: use zeros. assert_eq!( @@ -52,13 +52,13 @@ fn conv1d(dev: &Device) -> Result<()> { fn conv1d_small(dev: &Device) -> Result<()> { let t = Tensor::new(&[0.4056f32, -0.8689, -0.0773, -1.5630], dev)?.reshape((1, 1, 4))?; let w = Tensor::new(&[1f32, 0., 0.], dev)?.reshape((1, 1, 3))?; - let res = t.conv1d(&w, 0, 1)?; + let res = t.conv1d(&w, 0, 1, 1)?; assert_eq!(res.dims(), [1, 1, 2]); assert_eq!( test_utils::to_vec1_round(&res.flatten_all()?, 4)?, [0.4056, -0.8689] ); - let res = t.conv1d(&w, /*padding*/ 1, 1)?; + let res = t.conv1d(&w, /*padding*/ 1, 1, 1)?; assert_eq!(res.dims(), [1, 1, 4]); assert_eq!( test_utils::to_vec1_round(&res.flatten_all()?, 4)?, @@ -109,7 +109,7 @@ fn conv2d(dev: &Device) -> Result<()> { )?; let t = t.reshape((1, 4, 5, 5))?; let w = w.reshape((2, 4, 3, 3))?; - let res = t.conv2d(&w, 0, 1)?; + let res = t.conv2d(&w, 0, 1, 1)?; assert_eq!(res.dims(), [1, 2, 3, 3]); assert_eq!( test_utils::to_vec1_round(&res.flatten_all()?, 4)?, @@ -143,7 +143,7 @@ fn conv2d_small(dev: &Device) -> Result<()> { let w = Tensor::new(&[-0.9259f32, 1.3017], dev)?; let t = t.reshape((1, 2, 3, 3))?; let w = w.reshape((1, 2, 1, 1))?; - let res = t.conv2d(&w, 0, 1)?; + let res = t.conv2d(&w, 0, 1, 1)?; assert_eq!(res.dims(), [1, 1, 3, 3]); assert_eq!( test_utils::to_vec1_round(&res.flatten_all()?, 4)?, @@ -162,7 +162,7 
@@ fn conv2d_smaller(dev: &Device) -> Result<()> { let w = Tensor::new(&[1f32, 1., 1., 1., 1., 1., 1., 1., 1.], dev)?; let t = t.reshape((1, 1, 3, 3))?; let w = w.reshape((1, 1, 3, 3))?; - let res = t.conv2d(&w, 0, 1)?; + let res = t.conv2d(&w, 0, 1, 1)?; assert_eq!(res.dims(), [1, 1, 1, 1]); assert_eq!( test_utils::to_vec1_round(&res.flatten_all()?, 4)?, |