diff options
Diffstat (limited to 'candle-nn/src/var_builder.rs')
-rw-r--r-- | candle-nn/src/var_builder.rs | 156 |
1 files changed, 143 insertions, 13 deletions
diff --git a/candle-nn/src/var_builder.rs b/candle-nn/src/var_builder.rs index be1380b7..374260b0 100644 --- a/candle-nn/src/var_builder.rs +++ b/candle-nn/src/var_builder.rs @@ -1,7 +1,87 @@ -use candle::{safetensors::Load, DType, Device, Error, Result, Shape, Tensor}; +use candle::{safetensors::Load, DType, Device, Error, Result, Shape, Tensor, Var}; use safetensors::{slice::IndexOp, tensor::SafeTensors}; use std::collections::HashMap; -use std::sync::Arc; +use std::sync::{Arc, Mutex}; + +/// A `VarMap` is a store that holds named variables. Variables can be retrieved from the stores +/// and new variables can be added by providing some initialization config in case they are +/// missing. +/// `VarMap` structures can be serialized in the safetensors format. +#[derive(Clone)] +pub struct VarMap { + data: Arc<Mutex<HashMap<String, Var>>>, +} + +impl VarMap { + /// Create a new empty `VarMap`. + #[allow(clippy::new_without_default)] + pub fn new() -> Self { + let data = Arc::new(Mutex::new(HashMap::new())); + Self { data } + } + + /// Retrieve all the variables currently stored in the map. + pub fn all_vars(&self) -> Vec<Var> { + let tensor_data = self.data.lock().unwrap(); + #[allow(clippy::map_clone)] + tensor_data.values().map(|c| c.clone()).collect::<Vec<_>>() + } + + /// Save the map in the safetensors format. + pub fn save<P: AsRef<std::path::Path>>(&self, path: P) -> Result<()> { + let tensor_data = self.data.lock().unwrap(); + let data = tensor_data.iter().map(|(k, v)| (k, v.as_tensor())); + safetensors::tensor::serialize_to_file(data, &None, path.as_ref())?; + Ok(()) + } + + /// Load some values from a safetensors file and modify the existing variables to have these + /// values. + /// + /// Note that values for variables that are currently not in the map are not kept. + pub fn load<P: AsRef<std::path::Path>>(&mut self, path: P) -> Result<()> { + let path = path.as_ref(); + let data = unsafe { candle::safetensors::MmapedFile::new(path)? 
}; + let data = data.deserialize()?; + let mut tensor_data = self.data.lock().unwrap(); + for (name, var) in tensor_data.iter_mut() { + match data.tensor(name) { + Ok(data) => { + let data: Tensor = data.load(var.device())?; + if let Err(err) = var.set(&data) { + candle::bail!("error setting {name} using data from {path:?}: {err}",) + } + } + Err(_) => candle::bail!("cannot find tensor for {name}"), + } + } + Ok(()) + } + + /// Retrieve or add a new variable. + pub fn get<S: Into<Shape>>( + &self, + shape: S, + path: &str, + init: crate::Init, + dtype: DType, + device: &Device, + ) -> Result<Tensor> { + let shape = shape.into(); + let mut tensor_data = self.data.lock().unwrap(); + if let Some(tensor) = tensor_data.get(path) { + let tensor_shape = tensor.shape(); + if &shape != tensor_shape { + candle::bail!("shape mismatch on {path}: {shape:?} <> {tensor_shape:?}") + } + return Ok(tensor.as_tensor().clone()); + } + let var = init.var(shape, dtype, device)?; + let tensor = var.as_tensor().clone(); + tensor_data.insert(path.to_string(), var); + Ok(tensor) + } +} // TODO: Maybe we would want the storage to be generic, e.g. with Box<dyn> to avoid too many // generics. 
@@ -13,6 +93,7 @@ enum Tensors<'a> { Npz(candle::npy::NpzTensors), TensorMap(HashMap<String, Tensor>), Zeros, + VarMap(VarMap), } struct TensorData<'a> { @@ -64,6 +145,14 @@ impl<'a> TensorData<'a> { dtype, }) } + + fn from_varmap(varmap: &VarMap, dtype: DType, device: &Device) -> Self { + Self { + tensors: Tensors::VarMap(varmap.clone()), + device: device.clone(), + dtype, + } + } } #[derive(Clone)] @@ -99,6 +188,14 @@ impl<'a> VarBuilder<'a> { } } + pub fn from_varmap(varmap: &VarMap, dtype: DType, device: &Device) -> Self { + let data = TensorData::from_varmap(varmap, dtype, device); + Self { + data: Arc::new(data), + path: vec![], + } + } + pub fn from_npz<P: AsRef<std::path::Path>>( file: P, dtype: DType, @@ -154,11 +251,7 @@ impl<'a> VarBuilder<'a> { world_size: usize, ) -> Result<Tensor> { let data = self.data.as_ref(); - let path = if self.path.is_empty() { - tensor_name.to_string() - } else { - [&self.path.join("."), tensor_name].join(".") - }; + let path = self.path(tensor_name); let tensor = match &self.data.tensors { Tensors::SafeTensorWithRouting { routing, @@ -205,19 +298,16 @@ impl<'a> VarBuilder<'a> { let raw: Vec<u8> = iterator.into_iter().flatten().cloned().collect(); Tensor::from_raw_buffer(&raw, dtype, &shape, &data.device)? } - _ => unimplemented!(), + _ => candle::bail!("get_sharded is only available for safetensors"), }; Ok(tensor) } + /// Retrieve the tensor associated with the current name and path. pub fn get<S: Into<Shape>>(&self, s: S, tensor_name: &str) -> Result<Tensor> { let data = self.data.as_ref(); let s: Shape = s.into(); - let path = if self.path.is_empty() { - tensor_name.to_string() - } else { - [&self.path.join("."), tensor_name].join(".") - }; + let path = self.path(tensor_name); let tensor = match &self.data.tensors { Tensors::Zeros => Tensor::zeros(&s, data.dtype, &data.device)?.contiguous()?, Tensors::TensorMap(ts) => ts @@ -229,6 +319,18 @@ impl<'a> VarBuilder<'a> { .bt() })? 
.clone(), + Tensors::VarMap(varmap) => { + let data = varmap.data.lock().unwrap(); + data.get(&path) + .ok_or_else(|| { + Error::CannotFindTensor { + path: path.to_string(), + } + .bt() + })? + .as_tensor() + .clone() + } Tensors::Npz(npz) => npz.get(&path)?.ok_or_else(|| { Error::CannotFindTensor { path: path.to_string(), @@ -261,4 +363,32 @@ impl<'a> VarBuilder<'a> { } Ok(tensor) } + + /// Retrieve the tensor associated with the current name and path or initialize a new tensor if + /// it's missing. + /// + /// Tensor initialization is only available if the `VarBuilder` is backed by a `VarMap`. + pub fn get_or_init<S: Into<Shape>>( + &self, + s: S, + tensor_name: &str, + init: crate::Init, + ) -> Result<Tensor> { + let data = self.data.as_ref(); + match &self.data.tensors { + Tensors::VarMap(varmap) => { + let path = self.path(tensor_name); + varmap.get(s, &path, init, data.dtype, &data.device) + } + _ => self.get(s, tensor_name), + } + } + + fn path(&self, tensor_name: &str) -> String { + if self.path.is_empty() { + tensor_name.to_string() + } else { + [&self.path.join("."), tensor_name].join(".") + } + } } |