author     Laurent Mazare <laurent.mazare@gmail.com>  2023-07-13 20:14:10 +0100
committer  GitHub <noreply@github.com>                2023-07-13 20:14:10 +0100
commit     23e105cd941bdd1aaa1ae45816a9910fb829e74f
tree       85bf768748676f0cc0a223845d691279b544afb0
parent     3c02ea56b0dcc39c30ad0d41d942384cc28f65c2
Add the gradient for reduce-sum. (#162)
* Add the gradient for reduce-sum.
* And add the gradient for the broadcast ops.
* Add some backprop tests.
* Add some linear regression example.
-rw-r--r--  candle-core/src/backprop.rs      30
-rw-r--r--  candle-core/tests/grad_tests.rs  21
-rw-r--r--  candle-nn/src/optim.rs            2
-rw-r--r--  candle-nn/tests/optim.rs         28
4 files changed, 74 insertions, 7 deletions
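A note on the math behind the two new gradients: broadcasting only repeats values, so if x with shape (3,) is broadcast to y with shape (2, 3), every x_j contributes to both rows and dL/dx_j = dL/dy_{0,j} + dL/dy_{1,j}; the backward pass therefore sums the incoming gradient over each inserted or expanded dimension and relies on broadcast_add to accumulate the result back into a gradient of x's shape. Reduce-sum is the adjoint of that: each element of x contributes exactly once to y = x.sum(dims), so dL/dx is simply the incoming gradient broadcast back to the shape of x.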
diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs
index 61a81be0..2711da85 100644
--- a/candle-core/src/backprop.rs
+++ b/candle-core/src/backprop.rs
@@ -179,11 +179,33 @@ impl Tensor {
start_idx += len;
}
}
- Op::Broadcast(_arg) => {
- return Err(Error::BackwardNotSupported { op: "broadcast" })
+ Op::Broadcast(arg) => {
+ let arg_dims = arg.dims();
+ let node_dims = node.dims();
+ // The number of dims that have been inserted on the left.
+ let left_dims = node_dims.len() - arg_dims.len();
+ let mut sum_dims: Vec<usize> = (0..left_dims).collect();
+ for (dim, (node_dim, arg_dim)) in node_dims[left_dims..]
+ .iter()
+ .zip(arg_dims.iter())
+ .enumerate()
+ {
+ if node_dim != arg_dim {
+ sum_dims.push(dim + left_dims)
+ }
+ }
+
+ let mut arg_grad = grad.sum(sum_dims.as_slice())?;
+ // sum_dims has increasing values.
+ for &dim in sum_dims.iter().rev() {
+ arg_grad = arg_grad.squeeze(dim)?
+ }
+ let sum_grad = grads.or_insert(arg)?;
+ *sum_grad = sum_grad.broadcast_add(&arg_grad)?
}
- Op::Sum(_arg, _sum_dims) => {
- return Err(Error::BackwardNotSupported { op: "sum" })
+ Op::Sum(arg, _sum_dims) => {
+ let sum_grad = grads.or_insert(arg)?;
+ *sum_grad = sum_grad.broadcast_add(&grad)?
}
Op::ToDType(arg) => {
let sum_grad = grads.or_insert(arg)?;
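The broadcast rule above can be exercised end to end with ops that already appear in this diff. A minimal sketch, not part of the commit: the function name broadcast_grad_sketch and the values are illustrative, and it assumes the candle / anyhow crate names used by the tests below, with broadcast_add expanding the (3,) variable against a (2, 3) tensor so that each of its elements is used twice.

use anyhow::{Context, Result};
use candle::{Device, Tensor, Var};

fn broadcast_grad_sketch() -> Result<()> {
    let device = Device::Cpu;
    // x has shape (3,); broadcast_add against a (2, 3) tensor repeats it on each row.
    let x = Var::new(&[3f32, 1., 4.], &device)?;
    let x = x.as_tensor();
    let y = Tensor::new(&[[1f32, 2., 3.], [4., 5., 6.]], &device)?;
    // Reduce to a scalar loss so that every element of the broadcast result is summed once.
    let loss = y.broadcast_add(x)?.sum_all()?;
    let grads = loss.backward()?;
    let grad_x = grads.get(x).context("no grad for x")?;
    // Each x_j appears in two rows of the broadcast result, so its gradient is 2.
    assert_eq!(grad_x.to_vec1::<f32>()?, &[2., 2., 2.]);
    Ok(())
}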
diff --git a/candle-core/tests/grad_tests.rs b/candle-core/tests/grad_tests.rs
index 6f11879f..501b0d44 100644
--- a/candle-core/tests/grad_tests.rs
+++ b/candle-core/tests/grad_tests.rs
@@ -16,6 +16,26 @@ fn simple_grad(device: &Device) -> Result<()> {
Ok(())
}
+fn sum_grad(device: &Device) -> Result<()> {
+ let x = Var::new(&[3f32, 1., 4.], device)?;
+ let x = x.as_tensor();
+ let y = (x.sqr()?.sum(&[0])? * 2.)?;
+ let grads = y.backward()?;
+ let grad_x = grads.get(x).context("no grad for x")?;
+ assert_eq!(y.to_vec1::<f32>()?, [52.]);
+ // y = 2.x^2 so dy/dx = 4.x
+ assert_eq!(grad_x.to_vec1::<f32>()?, &[12., 4., 16.]);
+
+ // Same test as before but squeezing on the last dimension.
+ let y = (x.sqr()?.sum(&[0])? * 2.)?.squeeze(0)?;
+ let grads = y.backward()?;
+ let grad_x = grads.get(x).context("no grad for x")?;
+ assert_eq!(y.to_scalar::<f32>()?, 52.);
+ // y = 2.x^2 so dy/dx = 4.x
+ assert_eq!(grad_x.to_vec1::<f32>()?, &[12., 4., 16.]);
+ Ok(())
+}
+
fn matmul_grad(device: &Device) -> Result<()> {
let data: Vec<_> = (0..12).map(|i| i as f32).collect();
let x = Var::from_slice(&data, (2, 2, 3), device)?;
@@ -60,5 +80,6 @@ fn grad_descent(device: &Device) -> Result<()> {
}
test_device!(simple_grad, simple_grad_cpu, simple_grad_gpu);
+test_device!(sum_grad, sum_grad_cpu, sum_grad_gpu);
test_device!(matmul_grad, matmul_grad_cpu, matmul_grad_gpu);
test_device!(grad_descent, grad_descent_cpu, grad_descent_gpu);
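As a sanity check on the expected values in sum_grad above: y = 2 * (3^2 + 1^2 + 4^2) = 52, and dy/dx = 4 * x = [12., 4., 16.]; squeezing the summed dimension changes only the shape of y, so the squeezed variant must yield the exact same gradient.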
diff --git a/candle-nn/src/optim.rs b/candle-nn/src/optim.rs
index 741c51dc..a8b5b370 100644
--- a/candle-nn/src/optim.rs
+++ b/candle-nn/src/optim.rs
@@ -39,7 +39,7 @@ impl SGD {
let grads = loss.backward()?;
for var in self.vars.iter() {
if let Some(grad) = grads.get(var) {
- var.set(&var.sub(&(grad * self.learning_rate)?)?)?
+ var.set(&var.sub(&(grad * self.learning_rate)?)?)?;
}
}
Ok(())
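The touched line is the whole of the SGD update, var <- var - learning_rate * grad. A minimal single-step sketch, not part of the commit: the function name single_sgd_step is hypothetical, and it reuses only the API exercised by the tests below. Worked by hand, for loss = (x - 4.2)^2 at x = 0 the gradient is 2 * (0 - 4.2) = -8.4, so one step with learning rate 0.1 should move x to roughly 0.84.

use anyhow::Result;
use candle::{Device, Tensor, Var};
use candle_nn::SGD;

fn single_sgd_step() -> Result<()> {
    let x = Var::new(0f32, &Device::Cpu)?;
    let sgd = SGD::new(&[&x], 0.1);
    let target = Tensor::new(4.2f32, &Device::Cpu)?;
    // loss = (x - 4.2)^2, so dloss/dx = 2 * (x - 4.2) = -8.4 at x = 0.
    let loss = x.as_tensor().sub(&target)?.sqr()?;
    sgd.backward_step(&loss)?;
    // x - lr * grad = 0.0 - 0.1 * (-8.4), i.e. roughly 0.84.
    assert!((x.to_scalar::<f32>()? - 0.84).abs() < 1e-5);
    Ok(())
}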
diff --git a/candle-nn/tests/optim.rs b/candle-nn/tests/optim.rs
index 29aa987b..2ba6adbb 100644
--- a/candle-nn/tests/optim.rs
+++ b/candle-nn/tests/optim.rs
@@ -2,8 +2,8 @@
extern crate intel_mkl_src;
use anyhow::Result;
-use candle::{Device, Var};
-use candle_nn::SGD;
+use candle::{Device, Tensor, Var};
+use candle_nn::{Linear, SGD};
#[test]
fn sgd_optim() -> Result<()> {
@@ -17,3 +17,27 @@ fn sgd_optim() -> Result<()> {
assert_eq!(x.to_scalar::<f32>()?, 4.199999);
Ok(())
}
+
+#[test]
+fn sgd_linear_regression() -> Result<()> {
+ // Generate some linear data, y = 3.x1 + x2 - 2.
+ let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?;
+ let b_gen = Tensor::new(-2f32, &Device::Cpu)?;
+ let gen = Linear::new(w_gen, Some(b_gen));
+ let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?;
+ let sample_ys = gen.forward(&sample_xs)?;
+
+ // Now use backprop to run a linear regression between samples and get the coefficients back.
+ let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
+ let b = Var::new(0f32, &Device::Cpu)?;
+ let sgd = SGD::new(&[&w, &b], 0.004);
+ let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone()));
+ for _step in 0..1000 {
+ let ys = lin.forward(&sample_xs)?;
+ let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?;
+ sgd.backward_step(&loss)?;
+ }
+ assert_eq!(w.to_vec2::<f32>()?, &[[2.9983196, 0.99790204]]);
+ assert_eq!(b.to_scalar::<f32>()?, -1.9796902);
+ Ok(())
+}
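Since the samples are generated without noise from y = 3*x1 + x2 - 2 and the four points span the plane, the exact least-squares solution is w = [3., 1.] and b = -2.; the small residual gap in the assertions (2.998..., 0.997..., -1.979...) is simply what remains after 1000 SGD steps at learning rate 0.004.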