summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author    Laurent Mazare <laurent.mazare@gmail.com>  2023-10-10 10:05:41 +0200
committer GitHub <noreply@github.com>                2023-10-10 09:05:41 +0100
commit    9fea56d28e5f99529da8ed8df1eb508b0f163cc3 (patch)
tree      70562ee56e1c7ef30d289bc18a63c84d5e1e71f4
parent    bc3351bce4ce0ad24c69f872ffd51dc829fe88c8 (diff)
download  candle-9fea56d28e5f99529da8ed8df1eb508b0f163cc3.tar.gz
          candle-9fea56d28e5f99529da8ed8df1eb508b0f163cc3.tar.bz2
          candle-9fea56d28e5f99529da8ed8df1eb508b0f163cc3.zip
Only optimize float tensors. (#1069)
-rw-r--r--candle-core/src/dtype.rs14
-rw-r--r--candle-nn/src/optim.rs5
2 files changed, 19 insertions, 0 deletions
diff --git a/candle-core/src/dtype.rs b/candle-core/src/dtype.rs
index c7a1567f..94ca57d8 100644
--- a/candle-core/src/dtype.rs
+++ b/candle-core/src/dtype.rs
@@ -67,6 +67,20 @@ impl DType {
Self::F64 => 8,
}
}
+
+ pub fn is_int(&self) -> bool {
+ match self {
+ Self::U8 | Self::U32 | Self::I64 => true,
+ Self::BF16 | Self::F16 | Self::F32 | Self::F64 => false,
+ }
+ }
+
+ pub fn is_float(&self) -> bool {
+ match self {
+ Self::U8 | Self::U32 | Self::I64 => false,
+ Self::BF16 | Self::F16 | Self::F32 | Self::F64 => true,
+ }
+ }
}
pub trait WithDType:
diff --git a/candle-nn/src/optim.rs b/candle-nn/src/optim.rs
index 4294d75e..7704bb48 100644
--- a/candle-nn/src/optim.rs
+++ b/candle-nn/src/optim.rs
@@ -41,6 +41,10 @@ impl Optimizer for SGD {
type Config = f64;
fn new(vars: Vec<Var>, learning_rate: f64) -> Result<Self> {
+ let vars = vars
+ .into_iter()
+ .filter(|var| var.dtype().is_float())
+ .collect();
Ok(Self {
vars,
learning_rate,
@@ -116,6 +120,7 @@ impl Optimizer for AdamW {
fn new(vars: Vec<Var>, params: ParamsAdamW) -> Result<Self> {
let vars = vars
.into_iter()
+ .filter(|var| var.dtype().is_float())
.map(|var| {
let dtype = var.dtype();
let shape = var.shape();