diff options
Diffstat (limited to 'tests')
-rw-r--r-- | tests/grad_tests.rs | 26 |
1 files changed, 25 insertions, 1 deletions
diff --git a/tests/grad_tests.rs b/tests/grad_tests.rs index 56186e5d..77a32dfe 100644 --- a/tests/grad_tests.rs +++ b/tests/grad_tests.rs @@ -1,5 +1,5 @@ use anyhow::{Context, Result}; -use candle::{Device, Tensor}; +use candle::{Device, Shape, Tensor}; #[test] fn simple_grad() -> Result<()> { @@ -14,3 +14,27 @@ fn simple_grad() -> Result<()> { assert_eq!(grad_x.to_vec1::<f32>()?, [11., 7., 13.]); Ok(()) } + +#[test] +fn matmul_grad() -> Result<()> { + let data: Vec<_> = (0..12).map(|i| i as f32).collect(); + let x = Tensor::var_from_slice(&data, (2, 2, 3), &Device::Cpu)?; + let data: Vec<_> = (0..12).map(|i| i as f32).collect(); + let y = Tensor::var_from_slice(&data, (2, 3, 2), &Device::Cpu)?; + + let c = x.matmul(&y)?; + let grads = c.backward()?; + let grad_x = grads.get(&x).context("no grad for x")?; + let grad_y = grads.get(&y).context("no grad for y")?; + assert_eq!(grad_x.shape(), &Shape::from((2, 2, 3))); + assert_eq!(grad_y.shape(), &Shape::from((2, 3, 2))); + assert_eq!( + grad_x.as_slice::<f32>()?, + &[1., 5., 9., 1., 5., 9., 13., 17., 21., 13., 17., 21.] + ); + assert_eq!( + grad_y.as_slice::<f32>()?, + &[3., 3., 5., 5., 7., 7., 15., 15., 17., 17., 19., 19.] + ); + Ok(()) +} |