summaryrefslogtreecommitdiff
path: root/candle-nn/tests/optim.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-nn/tests/optim.rs')
-rw-r--r--candle-nn/tests/optim.rs28
1 files changed, 26 insertions, 2 deletions
diff --git a/candle-nn/tests/optim.rs b/candle-nn/tests/optim.rs
index 29aa987b..2ba6adbb 100644
--- a/candle-nn/tests/optim.rs
+++ b/candle-nn/tests/optim.rs
@@ -2,8 +2,8 @@
extern crate intel_mkl_src;
use anyhow::Result;
-use candle::{Device, Var};
-use candle_nn::SGD;
+use candle::{Device, Tensor, Var};
+use candle_nn::{Linear, SGD};
#[test]
fn sgd_optim() -> Result<()> {
@@ -17,3 +17,27 @@ fn sgd_optim() -> Result<()> {
assert_eq!(x.to_scalar::<f32>()?, 4.199999);
Ok(())
}
+
+#[test]
+fn sgd_linear_regression() -> Result<()> {
+ // Generate some linear data, y = 3.x1 + x2 - 2.
+ let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?;
+ let b_gen = Tensor::new(-2f32, &Device::Cpu)?;
+ let gen = Linear::new(w_gen, Some(b_gen));
+ let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?;
+ let sample_ys = gen.forward(&sample_xs)?;
+
+ // Now use backprop to run a linear regression between samples and get the coefficients back.
+ let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
+ let b = Var::new(0f32, &Device::Cpu)?;
+ let sgd = SGD::new(&[&w, &b], 0.004);
+ let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone()));
+ for _step in 0..1000 {
+ let ys = lin.forward(&sample_xs)?;
+ let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?;
+ sgd.backward_step(&loss)?;
+ }
+ assert_eq!(w.to_vec2::<f32>()?, &[[2.9983196, 0.99790204]]);
+ assert_eq!(b.to_scalar::<f32>()?, -1.9796902);
+ Ok(())
+}