diff options
Diffstat (limited to 'candle-core/tests/grad_tests.rs')
-rw-r--r-- | candle-core/tests/grad_tests.rs | 14 |
1 files changed, 10 insertions, 4 deletions
diff --git a/candle-core/tests/grad_tests.rs b/candle-core/tests/grad_tests.rs index 10eef780..d5c8f751 100644 --- a/candle-core/tests/grad_tests.rs +++ b/candle-core/tests/grad_tests.rs @@ -27,12 +27,18 @@ fn matmul_grad(device: &Device) -> Result<()> { assert_eq!(grad_x.shape(), &Shape::from((2, 2, 3))); assert_eq!(grad_y.shape(), &Shape::from((2, 3, 2))); assert_eq!( - &*grad_x.storage_data::<f32>()?, - &[1., 5., 9., 1., 5., 9., 13., 17., 21., 13., 17., 21.] + &*grad_x.to_vec3::<f32>()?, + &[ + [[1., 5., 9.], [1., 5., 9.]], + [[13., 17., 21.], [13., 17., 21.]] + ] ); assert_eq!( - &*grad_y.storage_data::<f32>()?, - &[3., 3., 5., 5., 7., 7., 15., 15., 17., 17., 19., 19.] + &*grad_y.to_vec3::<f32>()?, + &[ + [[3., 3.], [5., 5.], [7., 7.]], + [[15., 15.], [17., 17.], [19., 19.]] + ] ); Ok(()) } |