summaryrefslogtreecommitdiff
path: root/candle-core/src/variable.rs
diff options
context:
space:
mode:
authorJeffrey Dallatezza <jeffreydallatezza@gmail.com>2024-04-29 02:21:53 -0700
committerGitHub <noreply@github.com>2024-04-29 11:21:53 +0200
commita0d03aded1b8c4cfe96f7d6490f5c709c31b76f0 (patch)
treed9c1020da5b3f9fb2faff97edfa60251c1c4952e /candle-core/src/variable.rs
parent3bbb88fcb463a6bdbb0e71c7b2d211dd02681493 (diff)
downloadcandle-a0d03aded1b8c4cfe96f7d6490f5c709c31b76f0.tar.gz
candle-a0d03aded1b8c4cfe96f7d6490f5c709c31b76f0.tar.bz2
candle-a0d03aded1b8c4cfe96f7d6490f5c709c31b76f0.zip
Bug Fix: When converting a tensor to a variable, clone if the tensor is already a variable. (#2124)
* When converting a tensor to a variable, clone if the tensor is already a variable. * Add a test to ensure training a batch norm works with VarMaps --------- Co-authored-by: Jeffrey Dallatezza <jeffreydallatezza@Jeffreys-Laptop.local>
Diffstat (limited to 'candle-core/src/variable.rs')
-rw-r--r--candle-core/src/variable.rs9
1 files changed, 7 insertions, 2 deletions
diff --git a/candle-core/src/variable.rs b/candle-core/src/variable.rs
index bdf8da4a..1e4880e5 100644
--- a/candle-core/src/variable.rs
+++ b/candle-core/src/variable.rs
@@ -34,9 +34,14 @@ impl Var {
Ok(Self(inner))
}
+ // Convert a tensor to a variable, if the tensor is already a variable then it is returned as is.
pub fn from_tensor(t: &Tensor) -> Result<Self> {
- let inner = t.make_var()?;
- Ok(Self(inner))
+ if t.is_variable() {
+ Ok(Self(t.clone()))
+ } else {
+ let inner = t.make_var()?;
+ Ok(Self(inner))
+ }
}
pub fn rand_f64<S: Into<Shape>>(