summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-nn/src/var_map.rs38
1 files changed, 38 insertions, 0 deletions
diff --git a/candle-nn/src/var_map.rs b/candle-nn/src/var_map.rs
index f61fad23..c17558b7 100644
--- a/candle-nn/src/var_map.rs
+++ b/candle-nn/src/var_map.rs
@@ -57,6 +57,44 @@ impl VarMap {
Ok(())
}
+ /// Set a named variable to some value.
+ pub fn set_one<K: AsRef<String>, V: AsRef<Tensor>>(&mut self, name: K, value: V) -> Result<()> {
+ let tensor_data = self.data.lock().unwrap();
+ let name = name.as_ref();
+ match tensor_data.get(name) {
+ None => candle::bail!("cannot find {name} in VarMap"),
+ Some(var) => {
+ if let Err(err) = var.set(value.as_ref()) {
+ candle::bail!("error setting {name}: {err}",)
+ }
+ }
+ }
+ Ok(())
+ }
+
+ /// Set some named variables to some values.
+ ///
+ /// If an error is returned, some of the variables might have already been set to their new
+ /// values.
+ pub fn set<I: Iterator<Item = (K, V)>, K: AsRef<String>, V: AsRef<Tensor>>(
+ &mut self,
+ iter: I,
+ ) -> Result<()> {
+ let tensor_data = self.data.lock().unwrap();
+ for (name, value) in iter {
+ let name = name.as_ref();
+ match tensor_data.get(name) {
+ None => candle::bail!("cannot find {name} in VarMap"),
+ Some(var) => {
+ if let Err(err) = var.set(value.as_ref()) {
+ candle::bail!("error setting {name}: {err}",)
+ }
+ }
+ }
+ }
+ Ok(())
+ }
+
/// Retrieve or add a new variable.
pub fn get<S: Into<Shape>>(
&self,