summaryrefslogtreecommitdiff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/grad_tests.rs26
1 files changed, 25 insertions, 1 deletions
diff --git a/tests/grad_tests.rs b/tests/grad_tests.rs
index 56186e5d..77a32dfe 100644
--- a/tests/grad_tests.rs
+++ b/tests/grad_tests.rs
@@ -1,5 +1,5 @@
use anyhow::{Context, Result};
-use candle::{Device, Tensor};
+use candle::{Device, Shape, Tensor};
#[test]
fn simple_grad() -> Result<()> {
@@ -14,3 +14,27 @@ fn simple_grad() -> Result<()> {
assert_eq!(grad_x.to_vec1::<f32>()?, [11., 7., 13.]);
Ok(())
}
+
+#[test]
+fn matmul_grad() -> Result<()> {
+ let data: Vec<_> = (0..12).map(|i| i as f32).collect();
+ let x = Tensor::var_from_slice(&data, (2, 2, 3), &Device::Cpu)?;
+ let data: Vec<_> = (0..12).map(|i| i as f32).collect();
+ let y = Tensor::var_from_slice(&data, (2, 3, 2), &Device::Cpu)?;
+
+ let c = x.matmul(&y)?;
+ let grads = c.backward()?;
+ let grad_x = grads.get(&x).context("no grad for x")?;
+ let grad_y = grads.get(&y).context("no grad for y")?;
+ assert_eq!(grad_x.shape(), &Shape::from((2, 2, 3)));
+ assert_eq!(grad_y.shape(), &Shape::from((2, 3, 2)));
+ assert_eq!(
+ grad_x.as_slice::<f32>()?,
+ &[1., 5., 9., 1., 5., 9., 13., 17., 21., 13., 17., 21.]
+ );
+ assert_eq!(
+ grad_y.as_slice::<f32>()?,
+ &[3., 3., 5., 5., 7., 7., 15., 15., 17., 17., 19., 19.]
+ );
+ Ok(())
+}