diff options
author | Jeffrey Dallatezza <jeffreydallatezza@gmail.com> | 2024-04-29 02:21:53 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-04-29 11:21:53 +0200 |
commit | a0d03aded1b8c4cfe96f7d6490f5c709c31b76f0 (patch) | |
tree | d9c1020da5b3f9fb2faff97edfa60251c1c4952e /candle-core/src/variable.rs | |
parent | 3bbb88fcb463a6bdbb0e71c7b2d211dd02681493 (diff) | |
download | candle-a0d03aded1b8c4cfe96f7d6490f5c709c31b76f0.tar.gz candle-a0d03aded1b8c4cfe96f7d6490f5c709c31b76f0.tar.bz2 candle-a0d03aded1b8c4cfe96f7d6490f5c709c31b76f0.zip |
Bug Fix: When converting a tensor to a variable, clone if the tensor is already a variable. (#2124)
* When converting a tensor to a variable, clone if the tensor is already a variable.
* Add a test to ensure training a batch norm works with VarMaps
---------
Co-authored-by: Jeffrey Dallatezza <jeffreydallatezza@Jeffreys-Laptop.local>
Diffstat (limited to 'candle-core/src/variable.rs')
-rw-r--r-- | candle-core/src/variable.rs | 9 |
1 files changed, 7 insertions, 2 deletions
diff --git a/candle-core/src/variable.rs b/candle-core/src/variable.rs index bdf8da4a..1e4880e5 100644 --- a/candle-core/src/variable.rs +++ b/candle-core/src/variable.rs @@ -34,9 +34,14 @@ impl Var { Ok(Self(inner)) } + // Convert a tensor to a variable, if the tensor is already a variable then it is returned as is. pub fn from_tensor(t: &Tensor) -> Result<Self> { - let inner = t.make_var()?; - Ok(Self(inner)) + if t.is_variable() { + Ok(Self(t.clone())) + } else { + let inner = t.make_var()?; + Ok(Self(inner)) + } } pub fn rand_f64<S: Into<Shape>>( |