diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-07-13 11:04:40 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-07-13 11:04:40 +0100 |
commit | 50b0946a2dff2a65f8319ff6d798f12b2ea2a6fb (patch) | |
tree | c48c4ecc686748e10b678d347af8d46cb0955a6c /candle-core/tests/grad_tests.rs | |
parent | a3663ce2f2b03263075099baed677340974b7f4c (diff) | |
download | candle-50b0946a2dff2a65f8319ff6d798f12b2ea2a6fb.tar.gz candle-50b0946a2dff2a65f8319ff6d798f12b2ea2a6fb.tar.bz2 candle-50b0946a2dff2a65f8319ff6d798f12b2ea2a6fb.zip |
Tensor mutability (#154)
* Working towards tensor mutability.
* Use a ref-cell to provide tensor mutability.
Diffstat (limited to 'candle-core/tests/grad_tests.rs')
-rw-r--r-- | candle-core/tests/grad_tests.rs | 14 |
1 files changed, 10 insertions, 4 deletions
diff --git a/candle-core/tests/grad_tests.rs b/candle-core/tests/grad_tests.rs index 10eef780..d5c8f751 100644 --- a/candle-core/tests/grad_tests.rs +++ b/candle-core/tests/grad_tests.rs @@ -27,12 +27,18 @@ fn matmul_grad(device: &Device) -> Result<()> { assert_eq!(grad_x.shape(), &Shape::from((2, 2, 3))); assert_eq!(grad_y.shape(), &Shape::from((2, 3, 2))); assert_eq!( - &*grad_x.storage_data::<f32>()?, - &[1., 5., 9., 1., 5., 9., 13., 17., 21., 13., 17., 21.] + &*grad_x.to_vec3::<f32>()?, + &[ + [[1., 5., 9.], [1., 5., 9.]], + [[13., 17., 21.], [13., 17., 21.]] + ] ); assert_eq!( - &*grad_y.storage_data::<f32>()?, - &[3., 3., 5., 5., 7., 7., 15., 15., 17., 17., 19., 19.] + &*grad_y.to_vec3::<f32>()?, + &[ + [[3., 3.], [5., 5.], [7., 7.]], + [[15., 15.], [17., 17.], [19., 19.]] + ] ); Ok(()) } |