diff options
Diffstat (limited to 'candle-nn/src')
-rw-r--r-- | candle-nn/src/optim.rs | 5 |
1 files changed, 5 insertions, 0 deletions
diff --git a/candle-nn/src/optim.rs b/candle-nn/src/optim.rs index 4294d75e..7704bb48 100644 --- a/candle-nn/src/optim.rs +++ b/candle-nn/src/optim.rs @@ -41,6 +41,10 @@ impl Optimizer for SGD { type Config = f64; fn new(vars: Vec<Var>, learning_rate: f64) -> Result<Self> { + let vars = vars + .into_iter() + .filter(|var| var.dtype().is_float()) + .collect(); Ok(Self { vars, learning_rate, @@ -116,6 +120,7 @@ impl Optimizer for AdamW { fn new(vars: Vec<Var>, params: ParamsAdamW) -> Result<Self> { let vars = vars .into_iter() + .filter(|var| var.dtype().is_float()) .map(|var| { let dtype = var.dtype(); let shape = var.shape(); |