diff options
Diffstat (limited to 'candle-core/tests/grad_tests.rs')
-rw-r--r-- | candle-core/tests/grad_tests.rs | 21 |
1 files changed, 21 insertions, 0 deletions
diff --git a/candle-core/tests/grad_tests.rs b/candle-core/tests/grad_tests.rs index 6f11879f..501b0d44 100644 --- a/candle-core/tests/grad_tests.rs +++ b/candle-core/tests/grad_tests.rs @@ -16,6 +16,26 @@ fn simple_grad(device: &Device) -> Result<()> { Ok(()) } +fn sum_grad(device: &Device) -> Result<()> { + let x = Var::new(&[3f32, 1., 4.], device)?; + let x = x.as_tensor(); + let y = (x.sqr()?.sum(&[0])? * 2.)?; + let grads = y.backward()?; + let grad_x = grads.get(x).context("no grad for x")?; + assert_eq!(y.to_vec1::<f32>()?, [52.]); + // y = 2.x^2 so dy/dx = 4.x + assert_eq!(grad_x.to_vec1::<f32>()?, &[12., 4., 16.]); + + // Same test as before but squeezing on the last dimension. + let y = (x.sqr()?.sum(&[0])? * 2.)?.squeeze(0)?; + let grads = y.backward()?; + let grad_x = grads.get(x).context("no grad for x")?; + assert_eq!(y.to_scalar::<f32>()?, 52.); + // y = 2.x^2 so dy/dx = 4.x + assert_eq!(grad_x.to_vec1::<f32>()?, &[12., 4., 16.]); + Ok(()) +} + fn matmul_grad(device: &Device) -> Result<()> { let data: Vec<_> = (0..12).map(|i| i as f32).collect(); let x = Var::from_slice(&data, (2, 2, 3), device)?; @@ -60,5 +80,6 @@ fn grad_descent(device: &Device) -> Result<()> { } test_device!(simple_grad, simple_grad_cpu, simple_grad_gpu); +test_device!(sum_grad, sum_grad_cpu, sum_grad_gpu); test_device!(matmul_grad, matmul_grad_cpu, matmul_grad_gpu); test_device!(grad_descent, grad_descent_cpu, grad_descent_gpu); |