summaryrefslogtreecommitdiff
path: root/candle-core/tests/grad_tests.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-core/tests/grad_tests.rs')
-rw-r--r--candle-core/tests/grad_tests.rs21
1 files changed, 21 insertions, 0 deletions
diff --git a/candle-core/tests/grad_tests.rs b/candle-core/tests/grad_tests.rs
index 6f11879f..501b0d44 100644
--- a/candle-core/tests/grad_tests.rs
+++ b/candle-core/tests/grad_tests.rs
@@ -16,6 +16,26 @@ fn simple_grad(device: &Device) -> Result<()> {
Ok(())
}
+fn sum_grad(device: &Device) -> Result<()> {
+ let x = Var::new(&[3f32, 1., 4.], device)?;
+ let x = x.as_tensor();
+ let y = (x.sqr()?.sum(&[0])? * 2.)?;
+ let grads = y.backward()?;
+ let grad_x = grads.get(x).context("no grad for x")?;
+ assert_eq!(y.to_vec1::<f32>()?, [52.]);
+ // y = 2.x^2 so dy/dx = 4.x
+ assert_eq!(grad_x.to_vec1::<f32>()?, &[12., 4., 16.]);
+
+ // Same test as before but squeezing on the last dimension.
+ let y = (x.sqr()?.sum(&[0])? * 2.)?.squeeze(0)?;
+ let grads = y.backward()?;
+ let grad_x = grads.get(x).context("no grad for x")?;
+ assert_eq!(y.to_scalar::<f32>()?, 52.);
+ // y = 2.x^2 so dy/dx = 4.x
+ assert_eq!(grad_x.to_vec1::<f32>()?, &[12., 4., 16.]);
+ Ok(())
+}
+
fn matmul_grad(device: &Device) -> Result<()> {
let data: Vec<_> = (0..12).map(|i| i as f32).collect();
let x = Var::from_slice(&data, (2, 2, 3), device)?;
@@ -60,5 +80,6 @@ fn grad_descent(device: &Device) -> Result<()> {
}
test_device!(simple_grad, simple_grad_cpu, simple_grad_gpu);
+test_device!(sum_grad, sum_grad_cpu, sum_grad_gpu);
test_device!(matmul_grad, matmul_grad_cpu, matmul_grad_gpu);
test_device!(grad_descent, grad_descent_cpu, grad_descent_gpu);