Diffstat (limited to 'candle-nn/src/var_builder.rs')
-rw-r--r--  candle-nn/src/var_builder.rs  156
1 file changed, 143 insertions(+), 13 deletions(-)
diff --git a/candle-nn/src/var_builder.rs b/candle-nn/src/var_builder.rs
index be1380b7..374260b0 100644
--- a/candle-nn/src/var_builder.rs
+++ b/candle-nn/src/var_builder.rs
@@ -1,7 +1,87 @@
-use candle::{safetensors::Load, DType, Device, Error, Result, Shape, Tensor};
+use candle::{safetensors::Load, DType, Device, Error, Result, Shape, Tensor, Var};
use safetensors::{slice::IndexOp, tensor::SafeTensors};
use std::collections::HashMap;
-use std::sync::Arc;
+use std::sync::{Arc, Mutex};
+
+/// A `VarMap` is a store that holds named variables. Variables can be retrieved from the store
+/// and new variables can be added by providing some initialization config in case they are
+/// missing.
+/// `VarMap` structures can be serialized in the safetensors format.
+#[derive(Clone)]
+pub struct VarMap {
+ data: Arc<Mutex<HashMap<String, Var>>>,
+}
+
+impl VarMap {
+ /// Create a new empty `VarMap`.
+ #[allow(clippy::new_without_default)]
+ pub fn new() -> Self {
+ let data = Arc::new(Mutex::new(HashMap::new()));
+ Self { data }
+ }
+
+ /// Retrieve all the variables currently stored in the map.
+ pub fn all_vars(&self) -> Vec<Var> {
+ let tensor_data = self.data.lock().unwrap();
+ #[allow(clippy::map_clone)]
+ tensor_data.values().map(|c| c.clone()).collect::<Vec<_>>()
+ }
+
+ /// Save the map in the safetensors format.
+ pub fn save<P: AsRef<std::path::Path>>(&self, path: P) -> Result<()> {
+ let tensor_data = self.data.lock().unwrap();
+ let data = tensor_data.iter().map(|(k, v)| (k, v.as_tensor()));
+ safetensors::tensor::serialize_to_file(data, &None, path.as_ref())?;
+ Ok(())
+ }
+
+ /// Load some values from a safetensors file and modify the existing variables to have these
+ /// values.
+ ///
+ /// Note that values present in the file for variables that are not currently in the map are discarded.
+ pub fn load<P: AsRef<std::path::Path>>(&mut self, path: P) -> Result<()> {
+ let path = path.as_ref();
+ let data = unsafe { candle::safetensors::MmapedFile::new(path)? };
+ let data = data.deserialize()?;
+ let mut tensor_data = self.data.lock().unwrap();
+ for (name, var) in tensor_data.iter_mut() {
+ match data.tensor(name) {
+ Ok(data) => {
+ let data: Tensor = data.load(var.device())?;
+ if let Err(err) = var.set(&data) {
+ candle::bail!("error setting {name} using data from {path:?}: {err}",)
+ }
+ }
+ Err(_) => candle::bail!("cannot find tensor for {name}"),
+ }
+ }
+ Ok(())
+ }
+
+ /// Retrieve or add a new variable.
+ pub fn get<S: Into<Shape>>(
+ &self,
+ shape: S,
+ path: &str,
+ init: crate::Init,
+ dtype: DType,
+ device: &Device,
+ ) -> Result<Tensor> {
+ let shape = shape.into();
+ let mut tensor_data = self.data.lock().unwrap();
+ if let Some(tensor) = tensor_data.get(path) {
+ let tensor_shape = tensor.shape();
+ if &shape != tensor_shape {
+ candle::bail!("shape mismatch on {path}: {shape:?} <> {tensor_shape:?}")
+ }
+ return Ok(tensor.as_tensor().clone());
+ }
+ let var = init.var(shape, dtype, device)?;
+ let tensor = var.as_tensor().clone();
+ tensor_data.insert(path.to_string(), var);
+ Ok(tensor)
+ }
+}
// TODO: Maybe we would want the storage to be generic, e.g. with Box<dyn> to avoid too many
// generics.
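
Not part of the diff, but a quick sketch of how the new `VarMap` is meant to be used (assuming `VarMap` and `Init` are re-exported from the `candle_nn` crate root and that `Init` has a `Const` variant; the file name is illustrative):

use candle::{DType, Device, Result};
use candle_nn::{Init, VarMap};

fn varmap_sketch() -> Result<()> {
    let varmap = VarMap::new();
    // The first call creates and stores a 2x3 variable using the provided init.
    let w = varmap.get((2, 3), "w", Init::Const(1.0), DType::F32, &Device::Cpu)?;
    // A second call with the same name returns the existing variable
    // (and bails on a shape mismatch).
    let w_again = varmap.get((2, 3), "w", Init::Const(0.0), DType::F32, &Device::Cpu)?;
    assert_eq!(w.dims(), w_again.dims());
    // Serialize every stored variable in the safetensors format.
    varmap.save("vars.safetensors")?;
    Ok(())
}

Note that `load` takes `&mut self` and only overwrites variables that already exist in the map.
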
@@ -13,6 +93,7 @@ enum Tensors<'a> {
Npz(candle::npy::NpzTensors),
TensorMap(HashMap<String, Tensor>),
Zeros,
+ VarMap(VarMap),
}
struct TensorData<'a> {
@@ -64,6 +145,14 @@ impl<'a> TensorData<'a> {
dtype,
})
}
+
+ fn from_varmap(varmap: &VarMap, dtype: DType, device: &Device) -> Self {
+ Self {
+ tensors: Tensors::VarMap(varmap.clone()),
+ device: device.clone(),
+ dtype,
+ }
+ }
}
#[derive(Clone)]
@@ -99,6 +188,14 @@ impl<'a> VarBuilder<'a> {
}
}
+ pub fn from_varmap(varmap: &VarMap, dtype: DType, device: &Device) -> Self {
+ let data = TensorData::from_varmap(varmap, dtype, device);
+ Self {
+ data: Arc::new(data),
+ path: vec![],
+ }
+ }
+
pub fn from_npz<P: AsRef<std::path::Path>>(
file: P,
dtype: DType,
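
A short sketch of the new `from_varmap` constructor (again not part of the diff; same re-export assumptions as above, and it uses `get_or_init`, which is added at the end of this diff):

use candle::{DType, Device, Result};
use candle_nn::{Init, VarBuilder, VarMap};

fn from_varmap_sketch(device: &Device) -> Result<()> {
    let varmap = VarMap::new();
    // The builder stores a clone of the map (an `Arc` internally), so variables
    // created through it remain visible from `varmap`.
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, device);
    let _w = vb.get_or_init((3, 3), "weight", Init::Const(0.0))?;
    assert_eq!(varmap.all_vars().len(), 1);
    Ok(())
}
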
@@ -154,11 +251,7 @@ impl<'a> VarBuilder<'a> {
world_size: usize,
) -> Result<Tensor> {
let data = self.data.as_ref();
- let path = if self.path.is_empty() {
- tensor_name.to_string()
- } else {
- [&self.path.join("."), tensor_name].join(".")
- };
+ let path = self.path(tensor_name);
let tensor = match &self.data.tensors {
Tensors::SafeTensorWithRouting {
routing,
@@ -205,19 +298,16 @@ impl<'a> VarBuilder<'a> {
let raw: Vec<u8> = iterator.into_iter().flatten().cloned().collect();
Tensor::from_raw_buffer(&raw, dtype, &shape, &data.device)?
}
- _ => unimplemented!(),
+ _ => candle::bail!("get_sharded is only available for safetensors"),
};
Ok(tensor)
}
+ /// Retrieve the tensor associated with the current name and path.
pub fn get<S: Into<Shape>>(&self, s: S, tensor_name: &str) -> Result<Tensor> {
let data = self.data.as_ref();
let s: Shape = s.into();
- let path = if self.path.is_empty() {
- tensor_name.to_string()
- } else {
- [&self.path.join("."), tensor_name].join(".")
- };
+ let path = self.path(tensor_name);
let tensor = match &self.data.tensors {
Tensors::Zeros => Tensor::zeros(&s, data.dtype, &data.device)?.contiguous()?,
Tensors::TensorMap(ts) => ts
@@ -229,6 +319,18 @@ impl<'a> VarBuilder<'a> {
.bt()
})?
.clone(),
+ Tensors::VarMap(varmap) => {
+ let data = varmap.data.lock().unwrap();
+ data.get(&path)
+ .ok_or_else(|| {
+ Error::CannotFindTensor {
+ path: path.to_string(),
+ }
+ .bt()
+ })?
+ .as_tensor()
+ .clone()
+ }
Tensors::Npz(npz) => npz.get(&path)?.ok_or_else(|| {
Error::CannotFindTensor {
path: path.to_string(),
@@ -261,4 +363,32 @@ impl<'a> VarBuilder<'a> {
}
Ok(tensor)
}
+
+ /// Retrieve the tensor associated with the current name and path or initialize a new tensor if
+ /// it's missing.
+ ///
+ /// Tensor initialization is only available if the `VarBuilder` is backed by a `VarMap`.
+ pub fn get_or_init<S: Into<Shape>>(
+ &self,
+ s: S,
+ tensor_name: &str,
+ init: crate::Init,
+ ) -> Result<Tensor> {
+ let data = self.data.as_ref();
+ match &self.data.tensors {
+ Tensors::VarMap(varmap) => {
+ let path = self.path(tensor_name);
+ varmap.get(s, &path, init, data.dtype, &data.device)
+ }
+ _ => self.get(s, tensor_name),
+ }
+ }
+
+ fn path(&self, tensor_name: &str) -> String {
+ if self.path.is_empty() {
+ tensor_name.to_string()
+ } else {
+ [&self.path.join("."), tensor_name].join(".")
+ }
+ }
}