summaryrefslogtreecommitdiff
path: root/candle-nn/src/optim.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-nn/src/optim.rs')
-rw-r--r--candle-nn/src/optim.rs9
1 files changed, 8 insertions, 1 deletions
diff --git a/candle-nn/src/optim.rs b/candle-nn/src/optim.rs
index a8b5b370..d20ef284 100644
--- a/candle-nn/src/optim.rs
+++ b/candle-nn/src/optim.rs
@@ -8,7 +8,7 @@ pub struct SGD {
}
impl SGD {
- pub fn new(vars: &[&Var], learning_rate: f64) -> Self {
+ pub fn from_slice(vars: &[&Var], learning_rate: f64) -> Self {
let vars: Vec<_> = vars.iter().map(|&v| v.clone()).collect();
Self {
vars,
@@ -16,6 +16,13 @@ impl SGD {
}
}
+ pub fn new(vars: Vec<Var>, learning_rate: f64) -> Self {
+ Self {
+ vars,
+ learning_rate,
+ }
+ }
+
pub fn empty(learning_rate: f64) -> Self {
Self {
vars: vec![],