diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-10-10 10:05:41 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-10-10 09:05:41 +0100 |
commit | 9fea56d28e5f99529da8ed8df1eb508b0f163cc3 (patch) | |
tree | 70562ee56e1c7ef30d289bc18a63c84d5e1e71f4 /candle-nn/src | |
parent | bc3351bce4ce0ad24c69f872ffd51dc829fe88c8 (diff) | |
download | candle-9fea56d28e5f99529da8ed8df1eb508b0f163cc3.tar.gz candle-9fea56d28e5f99529da8ed8df1eb508b0f163cc3.tar.bz2 candle-9fea56d28e5f99529da8ed8df1eb508b0f163cc3.zip |
Only optimize float tensors. (#1069)
Diffstat (limited to 'candle-nn/src')
-rw-r--r-- | candle-nn/src/optim.rs | 5 |
1 files changed, 5 insertions, 0 deletions
diff --git a/candle-nn/src/optim.rs b/candle-nn/src/optim.rs index 4294d75e..7704bb48 100644 --- a/candle-nn/src/optim.rs +++ b/candle-nn/src/optim.rs @@ -41,6 +41,10 @@ impl Optimizer for SGD { type Config = f64; fn new(vars: Vec<Var>, learning_rate: f64) -> Result<Self> { + let vars = vars + .into_iter() + .filter(|var| var.dtype().is_float()) + .collect(); Ok(Self { vars, learning_rate, @@ -116,6 +120,7 @@ impl Optimizer for AdamW { fn new(vars: Vec<Var>, params: ParamsAdamW) -> Result<Self> { let vars = vars .into_iter() + .filter(|var| var.dtype().is_float()) .map(|var| { let dtype = var.dtype(); let shape = var.shape(); |