summaryrefslogtreecommitdiff
path: root/candle-nn/src/var_map.rs
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-08-20 14:25:07 +0100
committerGitHub <noreply@github.com>2023-08-20 14:25:07 +0100
commitd2622a8160a429835ab0a3aa2145ab8c9de4cdd9 (patch)
treef056cae6fd1c53337368958180eddcaeb17b75bc /candle-nn/src/var_map.rs
parent2fcb386f17558f56c8b8e24fbc36c9d3686e73ba (diff)
downloadcandle-d2622a8160a429835ab0a3aa2145ab8c9de4cdd9.tar.gz
candle-d2622a8160a429835ab0a3aa2145ab8c9de4cdd9.tar.bz2
candle-d2622a8160a429835ab0a3aa2145ab8c9de4cdd9.zip
Move the VarMap to a separate file (#525)
* Move the var-map struct in a separate file. * Fix some typos.
Diffstat (limited to 'candle-nn/src/var_map.rs')
-rw-r--r--candle-nn/src/var_map.rs87
1 files changed, 87 insertions, 0 deletions
diff --git a/candle-nn/src/var_map.rs b/candle-nn/src/var_map.rs
new file mode 100644
index 00000000..f61fad23
--- /dev/null
+++ b/candle-nn/src/var_map.rs
@@ -0,0 +1,87 @@
+use candle::{safetensors::Load, DType, Device, Result, Shape, Tensor, Var};
+use std::collections::HashMap;
+use std::sync::{Arc, Mutex};
+
+/// A `VarMap` is a store that holds named variables. Variables can be retrieved from the stores
+/// and new variables can be added by providing some initialization config in case they are
+/// missing.
+/// `VarMap` structures can be serialized in the safetensors format.
+#[derive(Clone)]
+pub struct VarMap {
+ data: Arc<Mutex<HashMap<String, Var>>>,
+}
+
+impl VarMap {
+ /// Create a new empty `VarMap`.
+ #[allow(clippy::new_without_default)]
+ pub fn new() -> Self {
+ let data = Arc::new(Mutex::new(HashMap::new()));
+ Self { data }
+ }
+
+ /// Retrieve all the variables currently stored in the map.
+ pub fn all_vars(&self) -> Vec<Var> {
+ let tensor_data = self.data.lock().unwrap();
+ #[allow(clippy::map_clone)]
+ tensor_data.values().map(|c| c.clone()).collect::<Vec<_>>()
+ }
+
+ /// Save the map in the safetensors format.
+ pub fn save<P: AsRef<std::path::Path>>(&self, path: P) -> Result<()> {
+ let tensor_data = self.data.lock().unwrap();
+ let data = tensor_data.iter().map(|(k, v)| (k, v.as_tensor()));
+ safetensors::tensor::serialize_to_file(data, &None, path.as_ref())?;
+ Ok(())
+ }
+
+ /// Load some values from a safetensors file and modify the existing variables to have these
+ /// values.
+ ///
+ /// Note that values for variables that are currently not in the map are not kept.
+ pub fn load<P: AsRef<std::path::Path>>(&mut self, path: P) -> Result<()> {
+ let path = path.as_ref();
+ let data = unsafe { candle::safetensors::MmapedFile::new(path)? };
+ let data = data.deserialize()?;
+ let mut tensor_data = self.data.lock().unwrap();
+ for (name, var) in tensor_data.iter_mut() {
+ match data.tensor(name) {
+ Ok(data) => {
+ let data: Tensor = data.load(var.device())?;
+ if let Err(err) = var.set(&data) {
+ candle::bail!("error setting {name} using data from {path:?}: {err}",)
+ }
+ }
+ Err(_) => candle::bail!("cannot find tensor for {name}"),
+ }
+ }
+ Ok(())
+ }
+
+ /// Retrieve or add a new variable.
+ pub fn get<S: Into<Shape>>(
+ &self,
+ shape: S,
+ path: &str,
+ init: crate::Init,
+ dtype: DType,
+ device: &Device,
+ ) -> Result<Tensor> {
+ let shape = shape.into();
+ let mut tensor_data = self.data.lock().unwrap();
+ if let Some(tensor) = tensor_data.get(path) {
+ let tensor_shape = tensor.shape();
+ if &shape != tensor_shape {
+ candle::bail!("shape mismatch on {path}: {shape:?} <> {tensor_shape:?}")
+ }
+ return Ok(tensor.as_tensor().clone());
+ }
+ let var = init.var(shape, dtype, device)?;
+ let tensor = var.as_tensor().clone();
+ tensor_data.insert(path.to_string(), var);
+ Ok(tensor)
+ }
+
+ pub fn data(&self) -> &Mutex<HashMap<String, Var>> {
+ &self.data
+ }
+}