Diffstat (limited to 'candle-nn/src')
-rw-r--r--   candle-nn/src/embedding.rs   |  12
-rw-r--r--   candle-nn/src/layer_norm.rs  |   6
-rw-r--r--   candle-nn/src/lib.rs         |   8
-rw-r--r--   candle-nn/src/linear.rs      |  23
-rw-r--r--   candle-nn/src/optim.rs       |   9
-rw-r--r--   candle-nn/src/var_builder.rs | 156
6 files changed, 195 insertions, 19 deletions
diff --git a/candle-nn/src/embedding.rs b/candle-nn/src/embedding.rs
index a0a853b0..050123be 100644
--- a/candle-nn/src/embedding.rs
+++ b/candle-nn/src/embedding.rs
@@ -28,3 +28,15 @@ impl Embedding {
         Ok(values)
     }
 }
+
+pub fn embedding(in_size: usize, out_size: usize, vb: crate::VarBuilder) -> Result<Embedding> {
+    let embeddings = vb.get_or_init(
+        (in_size, out_size),
+        "weight",
+        crate::Init::Randn {
+            mean: 0.,
+            stdev: 1.,
+        },
+    )?;
+    Ok(Embedding::new(embeddings, out_size))
+}
diff --git a/candle-nn/src/layer_norm.rs b/candle-nn/src/layer_norm.rs
index 8f8544bb..668f9a4b 100644
--- a/candle-nn/src/layer_norm.rs
+++ b/candle-nn/src/layer_norm.rs
@@ -62,3 +62,9 @@ impl LayerNorm {
         Ok(x)
     }
 }
+
+pub fn layer_norm(size: usize, eps: f64, vb: crate::VarBuilder) -> Result<LayerNorm> {
+    let weight = vb.get_or_init(size, "weight", crate::Init::Const(1.))?;
+    let bias = vb.get_or_init(size, "bias", crate::Init::Const(0.))?;
+    Ok(LayerNorm::new(weight, bias, eps))
+}
diff --git a/candle-nn/src/lib.rs b/candle-nn/src/lib.rs
index e8086238..45edfc46 100644
--- a/candle-nn/src/lib.rs
+++ b/candle-nn/src/lib.rs
@@ -15,9 +15,9 @@ pub mod vision;
 
 pub use activation::Activation;
 pub use conv::{Conv1d, Conv1dConfig};
-pub use embedding::Embedding;
+pub use embedding::{embedding, Embedding};
 pub use init::Init;
-pub use layer_norm::LayerNorm;
-pub use linear::Linear;
+pub use layer_norm::{layer_norm, LayerNorm};
+pub use linear::{linear, linear_no_bias, Linear};
 pub use optim::SGD;
-pub use var_builder::VarBuilder;
+pub use var_builder::{VarBuilder, VarMap};
diff --git a/candle-nn/src/linear.rs b/candle-nn/src/linear.rs
index 943011c9..a0bd925a 100644
--- a/candle-nn/src/linear.rs
+++ b/candle-nn/src/linear.rs
@@ -17,7 +17,7 @@
 //! assert_eq!(ys.to_vec2::<f32>()?, &[[210.0, 430.0, 650.0]]);
 //! # Ok(()) }
 //! ```
-use candle::Tensor;
+use candle::{Result, Tensor};
 
 #[derive(Debug)]
 pub struct Linear {
@@ -42,3 +42,24 @@ impl Linear {
         }
     }
 }
+
+/// Create or initialize a new linear layer.
+///
+/// This uses some default names for weight and biases, namely `"weight"` and `"bias"`.
+pub fn linear(in_dim: usize, out_dim: usize, vs: crate::VarBuilder) -> Result<Linear> {
+    let init_ws = crate::init::DEFAULT_KAIMING_NORMAL;
+    let ws = vs.get_or_init((out_dim, in_dim), "weight", init_ws)?;
+    let bound = 1. / (in_dim as f64).sqrt();
+    let init_bs = crate::Init::Uniform {
+        lo: -bound,
+        up: bound,
+    };
+    let bs = vs.get_or_init(out_dim, "bias", init_bs)?;
+    Ok(Linear::new(ws, Some(bs)))
+}
+
+pub fn linear_no_bias(in_dim: usize, out_dim: usize, vs: crate::VarBuilder) -> Result<Linear> {
+    let init_ws = crate::init::DEFAULT_KAIMING_NORMAL;
+    let ws = vs.get_or_init((out_dim, in_dim), "weight", init_ws)?;
+    Ok(Linear::new(ws, None))
+}
diff --git a/candle-nn/src/optim.rs b/candle-nn/src/optim.rs
index a8b5b370..d20ef284 100644
--- a/candle-nn/src/optim.rs
+++ b/candle-nn/src/optim.rs
@@ -8,7 +8,7 @@ pub struct SGD {
 }
 
 impl SGD {
-    pub fn new(vars: &[&Var], learning_rate: f64) -> Self {
+    pub fn from_slice(vars: &[&Var], learning_rate: f64) -> Self {
         let vars: Vec<_> = vars.iter().map(|&v| v.clone()).collect();
         Self {
             vars,
@@ -16,6 +16,13 @@ impl SGD {
         }
     }
 
+    pub fn new(vars: Vec<Var>, learning_rate: f64) -> Self {
+        Self {
+            vars,
+            learning_rate,
+        }
+    }
+
     pub fn empty(learning_rate: f64) -> Self {
         Self {
             vars: vec![],
diff --git a/candle-nn/src/var_builder.rs b/candle-nn/src/var_builder.rs
index be1380b7..374260b0 100644
--- a/candle-nn/src/var_builder.rs
+++ b/candle-nn/src/var_builder.rs
@@ -1,7 +1,87 @@
-use candle::{safetensors::Load, DType, Device, Error, Result, Shape, Tensor};
+use candle::{safetensors::Load, DType, Device, Error, Result, Shape, Tensor, Var};
 use safetensors::{slice::IndexOp, tensor::SafeTensors};
 use std::collections::HashMap;
-use std::sync::Arc;
+use std::sync::{Arc, Mutex};
+
+/// A `VarMap` is a store that holds named variables. Variables can be retrieved from the stores
+/// and new variables can be added by providing some initialization config in case they are
+/// missing.
+/// `VarMap` structures can be serialized in the safetensors format.
+#[derive(Clone)]
+pub struct VarMap {
+    data: Arc<Mutex<HashMap<String, Var>>>,
+}
+
+impl VarMap {
+    /// Create a new empty `VarMap`.
+    #[allow(clippy::new_without_default)]
+    pub fn new() -> Self {
+        let data = Arc::new(Mutex::new(HashMap::new()));
+        Self { data }
+    }
+
+    /// Retrieve all the variables currently stored in the map.
+    pub fn all_vars(&self) -> Vec<Var> {
+        let tensor_data = self.data.lock().unwrap();
+        #[allow(clippy::map_clone)]
+        tensor_data.values().map(|c| c.clone()).collect::<Vec<_>>()
+    }
+
+    /// Save the map in the safetensors format.
+    pub fn save<P: AsRef<std::path::Path>>(&self, path: P) -> Result<()> {
+        let tensor_data = self.data.lock().unwrap();
+        let data = tensor_data.iter().map(|(k, v)| (k, v.as_tensor()));
+        safetensors::tensor::serialize_to_file(data, &None, path.as_ref())?;
+        Ok(())
+    }
+
+    /// Load some values from a safetensors file and modify the existing variables to have these
+    /// values.
+    ///
+    /// Note that values for variables that are currently not in the map are not kept.
+    pub fn load<P: AsRef<std::path::Path>>(&mut self, path: P) -> Result<()> {
+        let path = path.as_ref();
+        let data = unsafe { candle::safetensors::MmapedFile::new(path)? };
+        let data = data.deserialize()?;
+        let mut tensor_data = self.data.lock().unwrap();
+        for (name, var) in tensor_data.iter_mut() {
+            match data.tensor(name) {
+                Ok(data) => {
+                    let data: Tensor = data.load(var.device())?;
+                    if let Err(err) = var.set(&data) {
+                        candle::bail!("error setting {name} using data from {path:?}: {err}",)
+                    }
+                }
+                Err(_) => candle::bail!("cannot find tensor for {name}"),
+            }
+        }
+        Ok(())
+    }
+
+    /// Retrieve or add a new variable.
+    pub fn get<S: Into<Shape>>(
+        &self,
+        shape: S,
+        path: &str,
+        init: crate::Init,
+        dtype: DType,
+        device: &Device,
+    ) -> Result<Tensor> {
+        let shape = shape.into();
+        let mut tensor_data = self.data.lock().unwrap();
+        if let Some(tensor) = tensor_data.get(path) {
+            let tensor_shape = tensor.shape();
+            if &shape != tensor_shape {
+                candle::bail!("shape mismatch on {path}: {shape:?} <> {tensor_shape:?}")
+            }
+            return Ok(tensor.as_tensor().clone());
+        }
+        let var = init.var(shape, dtype, device)?;
+        let tensor = var.as_tensor().clone();
+        tensor_data.insert(path.to_string(), var);
+        Ok(tensor)
+    }
+}
 
 // TODO: Maybe we would want the storage to be generic, e.g. with Box<dyn> to avoid too many
 // generics.
@@ -13,6 +93,7 @@ enum Tensors<'a> {
     Npz(candle::npy::NpzTensors),
     TensorMap(HashMap<String, Tensor>),
     Zeros,
+    VarMap(VarMap),
 }
 
 struct TensorData<'a> {
@@ -64,6 +145,14 @@ impl<'a> TensorData<'a> {
             dtype,
         })
     }
+
+    fn from_varmap(varmap: &VarMap, dtype: DType, device: &Device) -> Self {
+        Self {
+            tensors: Tensors::VarMap(varmap.clone()),
+            device: device.clone(),
+            dtype,
+        }
+    }
 }
 
 #[derive(Clone)]
@@ -99,6 +188,14 @@ impl<'a> VarBuilder<'a> {
         }
     }
 
+    pub fn from_varmap(varmap: &VarMap, dtype: DType, device: &Device) -> Self {
+        let data = TensorData::from_varmap(varmap, dtype, device);
+        Self {
+            data: Arc::new(data),
+            path: vec![],
+        }
+    }
+
     pub fn from_npz<P: AsRef<std::path::Path>>(
         file: P,
         dtype: DType,
@@ -154,11 +251,7 @@ impl<'a> VarBuilder<'a> {
         world_size: usize,
     ) -> Result<Tensor> {
         let data = self.data.as_ref();
-        let path = if self.path.is_empty() {
-            tensor_name.to_string()
-        } else {
-            [&self.path.join("."), tensor_name].join(".")
-        };
+        let path = self.path(tensor_name);
         let tensor = match &self.data.tensors {
             Tensors::SafeTensorWithRouting {
                 routing,
@@ -205,19 +298,16 @@ impl<'a> VarBuilder<'a> {
                 let raw: Vec<u8> = iterator.into_iter().flatten().cloned().collect();
                 Tensor::from_raw_buffer(&raw, dtype, &shape, &data.device)?
             }
-            _ => unimplemented!(),
+            _ => candle::bail!("get_sharded is only available for safetensors"),
         };
         Ok(tensor)
     }
 
+    /// Retrieve the tensor associted with the current name and path.
     pub fn get<S: Into<Shape>>(&self, s: S, tensor_name: &str) -> Result<Tensor> {
         let data = self.data.as_ref();
         let s: Shape = s.into();
-        let path = if self.path.is_empty() {
-            tensor_name.to_string()
-        } else {
-            [&self.path.join("."), tensor_name].join(".")
-        };
+        let path = self.path(tensor_name);
         let tensor = match &self.data.tensors {
             Tensors::Zeros => Tensor::zeros(&s, data.dtype, &data.device)?.contiguous()?,
             Tensors::TensorMap(ts) => ts
@@ -229,6 +319,18 @@ impl<'a> VarBuilder<'a> {
                     .bt()
                 })?
                 .clone(),
+            Tensors::VarMap(varmap) => {
+                let data = varmap.data.lock().unwrap();
+                data.get(&path)
+                    .ok_or_else(|| {
+                        Error::CannotFindTensor {
+                            path: path.to_string(),
+                        }
+                        .bt()
+                    })?
+                    .as_tensor()
+                    .clone()
+            }
             Tensors::Npz(npz) => npz.get(&path)?.ok_or_else(|| {
                 Error::CannotFindTensor {
                     path: path.to_string(),
@@ -261,4 +363,32 @@ impl<'a> VarBuilder<'a> {
         }
         Ok(tensor)
     }
+
+    /// Retrieve the tensor associted with the current name and path or initialize a new tensor if
+    /// it's missing.
+    ///
+    /// Tensor initialization is only available if the `VarBuilder` is backed by a `VarMap`.
+    pub fn get_or_init<S: Into<Shape>>(
+        &self,
+        s: S,
+        tensor_name: &str,
+        init: crate::Init,
+    ) -> Result<Tensor> {
+        let data = self.data.as_ref();
+        match &self.data.tensors {
+            Tensors::VarMap(varmap) => {
+                let path = self.path(tensor_name);
+                varmap.get(s, &path, init, data.dtype, &data.device)
+            }
+            _ => self.get(s, tensor_name),
+        }
+    }
+
+    fn path(&self, tensor_name: &str) -> String {
+        if self.path.is_empty() {
+            tensor_name.to_string()
+        } else {
+            [&self.path.join("."), tensor_name].join(".")
+        }
+    }
 }
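Taken together, these changes let a model own its trainable parameters through a VarMap while layers are still built through the usual VarBuilder path: get_or_init creates missing variables on first use, all_vars hands them to the optimizer, and save/load checkpoint them as safetensors. The sketch below is an illustration of how a downstream crate might wire the new pieces together at this revision; it assumes Linear keeps the forward method used in its module docs, and the checkpoint file name is arbitrary.

use candle::{DType, Device, Result, Tensor};
use candle_nn::{linear, Linear, VarBuilder, VarMap, SGD};

fn main() -> Result<()> {
    let device = Device::Cpu;

    // An empty VarMap: variables are created lazily via get_or_init when layers are built.
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device);

    // Building the layer registers "weight" and "bias" in the VarMap using the
    // default initializers (Kaiming-normal weights, uniform bias).
    let model: Linear = linear(4, 2, vb)?;

    // The new SGD::new constructor takes ownership of every variable in the map.
    let _sgd = SGD::new(varmap.all_vars(), 0.01);

    // A forward pass works as before.
    let xs = Tensor::zeros((1, 4), DType::F32, &device)?;
    let _ys = model.forward(&xs)?;

    // Checkpointing: the variables are serialized to a safetensors file
    // (the name is arbitrary) and can later be restored with VarMap::load.
    varmap.save("checkpoint.safetensors")?;
    Ok(())
}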