summaryrefslogtreecommitdiff
path: root/candle-nn/tests/optim.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-nn/tests/optim.rs')
-rw-r--r--candle-nn/tests/optim.rs4
1 files changed, 2 insertions, 2 deletions
diff --git a/candle-nn/tests/optim.rs b/candle-nn/tests/optim.rs
index 8228e435..54c378cc 100644
--- a/candle-nn/tests/optim.rs
+++ b/candle-nn/tests/optim.rs
@@ -8,7 +8,7 @@ use candle_nn::{Linear, SGD};
#[test]
fn sgd_optim() -> Result<()> {
let x = Var::new(0f32, &Device::Cpu)?;
- let sgd = SGD::new(&[&x], 0.1);
+ let sgd = SGD::new(vec![x.clone()], 0.1);
let xt = x.as_tensor();
for _step in 0..100 {
let loss = ((xt - 4.2)? * (xt - 4.2)?)?;
@@ -54,7 +54,7 @@ fn sgd_linear_regression() -> Result<()> {
// Now use backprop to run a linear regression between samples and get the coefficients back.
let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
let b = Var::new(0f32, &Device::Cpu)?;
- let sgd = SGD::new(&[&w, &b], 0.004);
+ let sgd = SGD::new(vec![w.clone(), b.clone()], 0.004);
let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone()));
for _step in 0..1000 {
let ys = lin.forward(&sample_xs)?;