diff options
Diffstat (limited to 'candle-nn/src/optim.rs')
-rw-r--r-- | candle-nn/src/optim.rs | 9 |
1 files changed, 8 insertions, 1 deletions
diff --git a/candle-nn/src/optim.rs b/candle-nn/src/optim.rs index a8b5b370..d20ef284 100644 --- a/candle-nn/src/optim.rs +++ b/candle-nn/src/optim.rs @@ -8,7 +8,7 @@ pub struct SGD { } impl SGD { - pub fn new(vars: &[&Var], learning_rate: f64) -> Self { + pub fn from_slice(vars: &[&Var], learning_rate: f64) -> Self { let vars: Vec<_> = vars.iter().map(|&v| v.clone()).collect(); Self { vars, @@ -16,6 +16,13 @@ impl SGD { } } + pub fn new(vars: Vec<Var>, learning_rate: f64) -> Self { + Self { + vars, + learning_rate, + } + } + pub fn empty(learning_rate: f64) -> Self { Self { vars: vec![], |