Diffstat (limited to 'candle-nn')
 candle-nn/src/optim.rs   |  2 +-
 candle-nn/tests/optim.rs | 28 ++++++++++++++++++++++++++--
 2 files changed, 27 insertions(+), 3 deletions(-)
diff --git a/candle-nn/src/optim.rs b/candle-nn/src/optim.rs
index 741c51dc..a8b5b370 100644
--- a/candle-nn/src/optim.rs
+++ b/candle-nn/src/optim.rs
@@ -39,7 +39,7 @@ impl SGD {
         let grads = loss.backward()?;
         for var in self.vars.iter() {
             if let Some(grad) = grads.get(var) {
-                var.set(&var.sub(&(grad * self.learning_rate)?)?)?
+                var.set(&var.sub(&(grad * self.learning_rate)?)?)?;
             }
         }
         Ok(())
diff --git a/candle-nn/tests/optim.rs b/candle-nn/tests/optim.rs
index 29aa987b..2ba6adbb 100644
--- a/candle-nn/tests/optim.rs
+++ b/candle-nn/tests/optim.rs
@@ -2,8 +2,8 @@
 extern crate intel_mkl_src;
 
 use anyhow::Result;
-use candle::{Device, Var};
-use candle_nn::SGD;
+use candle::{Device, Tensor, Var};
+use candle_nn::{Linear, SGD};
 
 #[test]
 fn sgd_optim() -> Result<()> {
@@ -17,3 +17,27 @@ fn sgd_optim() -> Result<()> {
     assert_eq!(x.to_scalar::<f32>()?, 4.199999);
     Ok(())
 }
+
+#[test]
+fn sgd_linear_regression() -> Result<()> {
+    // Generate some linear data, y = 3.x1 + x2 - 2.
+    let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?;
+    let b_gen = Tensor::new(-2f32, &Device::Cpu)?;
+    let gen = Linear::new(w_gen, Some(b_gen));
+    let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?;
+    let sample_ys = gen.forward(&sample_xs)?;
+
+    // Now use backprop to run a linear regression between samples and get the coefficients back.
+    let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
+    let b = Var::new(0f32, &Device::Cpu)?;
+    let sgd = SGD::new(&[&w, &b], 0.004);
+    let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone()));
+    for _step in 0..1000 {
+        let ys = lin.forward(&sample_xs)?;
+        let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?;
+        sgd.backward_step(&loss)?;
+    }
+    assert_eq!(w.to_vec2::<f32>()?, &[[2.9983196, 0.99790204]]);
+    assert_eq!(b.to_scalar::<f32>()?, -1.9796902);
+    Ok(())
+}
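For reference, the update patched in the optim.rs hunk above is plain gradient descent: after each backward pass, every tracked variable is set to var - learning_rate * grad. The following is a minimal sketch of driving that update on a single scalar variable, assuming only the candle / candle-nn APIs that already appear in this diff (Var, Tensor, SGD::new, backward_step, as_tensor); the target value 4.2 and the iteration count are illustrative choices, not part of the change:

    use anyhow::Result;
    use candle::{Device, Tensor, Var};
    use candle_nn::SGD;

    fn main() -> Result<()> {
        // A single trainable scalar, initialised to 0.
        let x = Var::new(0f32, &Device::Cpu)?;
        let sgd = SGD::new(&[&x], 0.1);
        // Illustrative target: minimise the scalar loss (x - 4.2)^2.
        let target = Tensor::new(4.2f32, &Device::Cpu)?;
        for _step in 0..100 {
            let loss = x.as_tensor().sub(&target)?.sqr()?;
            // backward_step computes gradients and applies x <- x - learning_rate * grad.
            sgd.backward_step(&loss)?;
        }
        // x should now have converged close to 4.2.
        println!("x = {}", x.as_tensor().to_scalar::<f32>()?);
        Ok(())
    }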