summaryrefslogtreecommitdiff
path: root/candle-nn/src
diff options
context:
space:
mode:
Diffstat (limited to 'candle-nn/src')
-rw-r--r--candle-nn/src/optim.rs5
1 files changed, 5 insertions, 0 deletions
diff --git a/candle-nn/src/optim.rs b/candle-nn/src/optim.rs
index 4294d75e..7704bb48 100644
--- a/candle-nn/src/optim.rs
+++ b/candle-nn/src/optim.rs
@@ -41,6 +41,10 @@ impl Optimizer for SGD {
type Config = f64;
fn new(vars: Vec<Var>, learning_rate: f64) -> Result<Self> {
+ let vars = vars
+ .into_iter()
+ .filter(|var| var.dtype().is_float())
+ .collect();
Ok(Self {
vars,
learning_rate,
@@ -116,6 +120,7 @@ impl Optimizer for AdamW {
fn new(vars: Vec<Var>, params: ParamsAdamW) -> Result<Self> {
let vars = vars
.into_iter()
+ .filter(|var| var.dtype().is_float())
.map(|var| {
let dtype = var.dtype();
let shape = var.shape();