diff options
Diffstat (limited to 'candle-nn/src')
-rw-r--r-- | candle-nn/src/var_builder.rs | 103 |
1 files changed, 84 insertions, 19 deletions
diff --git a/candle-nn/src/var_builder.rs b/candle-nn/src/var_builder.rs index 203640b0..d71b5822 100644 --- a/candle-nn/src/var_builder.rs +++ b/candle-nn/src/var_builder.rs @@ -1,53 +1,118 @@ use candle::{safetensors::SafeTensors, DType, Device, Shape, Tensor}; use std::collections::HashMap; +use std::sync::Arc; -pub struct VarBuilder<'a> { - safetensors: Option<(HashMap<String, usize>, Vec<SafeTensors<'a>>)>, +struct SafeTensorWithRouting<'a> { + routing: HashMap<String, usize>, + safetensors: Vec<SafeTensors<'a>>, +} + +struct TensorData<'a> { + // TODO: Make this part generic, probably via some Box<dyn> to avoid too many generics. + safetensors: Option<SafeTensorWithRouting<'a>>, pub dtype: DType, pub device: Device, } -impl<'a> VarBuilder<'a> { - pub fn from_safetensors( - safetensors: Vec<SafeTensors<'a>>, - dtype: DType, - device: &Device, - ) -> Self { +impl<'a> TensorData<'a> { + fn from_safetensors(safetensors: Vec<SafeTensors<'a>>, dtype: DType, device: &Device) -> Self { let mut routing = HashMap::new(); for (index, sf) in safetensors.iter().enumerate() { for k in sf.names() { routing.insert(k.to_string(), index); } } + let safetensors = SafeTensorWithRouting { + routing, + safetensors, + }; Self { - safetensors: Some((routing, safetensors)), + safetensors: Some(safetensors), device: device.clone(), dtype, } } - pub fn zeros(dtype: DType, device: Device) -> Self { + fn zeros(dtype: DType, device: &Device) -> Self { Self { safetensors: None, - device, + device: device.clone(), dtype, } } +} + +#[derive(Clone)] +pub struct VarBuilder<'a> { + data: Arc<TensorData<'a>>, + path: Vec<String>, +} + +impl<'a> VarBuilder<'a> { + /// Create a `VarBuilder` accessing data from the safetensors storage. The initial path is + /// set to the root path and sub-paths can be created via the `push_prefix` method. 
+ pub fn from_safetensors(st: Vec<SafeTensors<'a>>, dtype: DType, device: &Device) -> Self { + let data = TensorData::from_safetensors(st, dtype, device); + Self { + data: Arc::new(data), + path: vec![], + } + } + + pub fn zeros(dtype: DType, device: &Device) -> Self { + let data = TensorData::zeros(dtype, device); + Self { + data: Arc::new(data), + path: vec![], + } + } + + pub fn push_prefix(&self, s: &str) -> Self { + let mut path = self.path.clone(); + path.push(s.to_string()); + Self { + data: self.data.clone(), + path, + } + } + /// Short alias for `push_prefix`. + pub fn pp(&self, s: &str) -> Self { + self.push_prefix(s) + } + + pub fn device(&self) -> &Device { + &self.data.device + } + + pub fn dtype(&self) -> DType { + self.data.dtype + } +} + +impl<'a> VarBuilder<'a> { pub fn get<S: Into<Shape>>(&self, s: S, tensor_name: &str) -> candle::Result<Tensor> { + let data = self.data.as_ref(); let s: Shape = s.into(); - match &self.safetensors { - None => Tensor::zeros(s, self.dtype, &self.device), - Some((routing, safetensors)) => { + match &self.data.safetensors { + None => Tensor::zeros(s, data.dtype, &data.device), + Some(SafeTensorWithRouting { + routing, + safetensors, + }) => { + let path = if self.path.is_empty() { + tensor_name.to_string() + } else { + [&self.path.join("."), tensor_name].join(".") + }; // Unwrap or 0 just to let the proper error flow. - let index = routing.get(tensor_name).unwrap_or(&0); + let index = routing.get(&path).unwrap_or(&0); let tensor = safetensors[*index] - .tensor(tensor_name, &self.device)? - .to_dtype(self.dtype)?; + .tensor(&path, &data.device)? + .to_dtype(data.dtype)?; if *tensor.shape() != s { - let msg = format!("shape mismatch for {tensor_name}"); Err(candle::Error::UnexpectedShape { - msg, + msg: format!("shape mismatch for {path}"), expected: s, got: tensor.shape().clone(), })? |