diff options
Diffstat (limited to 'candle-nn/tests/optim.rs')
-rw-r--r-- | candle-nn/tests/optim.rs | 4 |
1 files changed, 2 insertions, 2 deletions
diff --git a/candle-nn/tests/optim.rs b/candle-nn/tests/optim.rs index 8228e435..54c378cc 100644 --- a/candle-nn/tests/optim.rs +++ b/candle-nn/tests/optim.rs @@ -8,7 +8,7 @@ use candle_nn::{Linear, SGD}; #[test] fn sgd_optim() -> Result<()> { let x = Var::new(0f32, &Device::Cpu)?; - let sgd = SGD::new(&[&x], 0.1); + let sgd = SGD::new(vec![x.clone()], 0.1); let xt = x.as_tensor(); for _step in 0..100 { let loss = ((xt - 4.2)? * (xt - 4.2)?)?; @@ -54,7 +54,7 @@ fn sgd_linear_regression() -> Result<()> { // Now use backprop to run a linear regression between samples and get the coefficients back. let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?; let b = Var::new(0f32, &Device::Cpu)?; - let sgd = SGD::new(&[&w, &b], 0.004); + let sgd = SGD::new(vec![w.clone(), b.clone()], 0.004); let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone())); for _step in 0..1000 { let ys = lin.forward(&sample_xs)?; |