diff options
-rw-r--r-- | candle-core/src/safetensors.rs | 22 | ||||
-rw-r--r-- | candle-examples/examples/t5/main.rs | 26 | ||||
-rw-r--r-- | candle-nn/src/var_builder.rs | 43 |
3 files changed, 61 insertions, 30 deletions
diff --git a/candle-core/src/safetensors.rs b/candle-core/src/safetensors.rs index 7e23c582..12df7fbe 100644 --- a/candle-core/src/safetensors.rs +++ b/candle-core/src/safetensors.rs @@ -321,6 +321,18 @@ impl MmapedSafetensors { } pub fn load(&self, name: &str, dev: &Device) -> Result<Tensor> { + self.get(name)?.load(dev) + } + + pub fn tensors(&self) -> Vec<(String, st::TensorView<'_>)> { + let mut tensors = vec![]; + for safetensors in self.safetensors.iter() { + tensors.push(safetensors.get().0.tensors()) + } + tensors.into_iter().flatten().collect() + } + + pub fn get(&self, name: &str) -> Result<st::TensorView<'_>> { let index = match &self.routing { None => 0, Some(routing) => { @@ -333,15 +345,7 @@ impl MmapedSafetensors { *index } }; - self.safetensors[index].get().0.tensor(name)?.load(dev) - } - - pub fn tensors(&self) -> Vec<(String, st::TensorView<'_>)> { - let mut tensors = vec![]; - for safetensors in self.safetensors.iter() { - tensors.push(safetensors.get().0.tensors()) - } - tensors.into_iter().flatten().collect() + Ok(self.safetensors[index].get().0.tensor(name)?) } } diff --git a/candle-examples/examples/t5/main.rs b/candle-examples/examples/t5/main.rs index 55929c33..71106497 100644 --- a/candle-examples/examples/t5/main.rs +++ b/candle-examples/examples/t5/main.rs @@ -122,30 +122,16 @@ impl T5ModelBuilder { } pub fn build_encoder(&self) -> Result<t5::T5EncoderModel> { - let weights = self - .weights_filename - .iter() - .map(|f| unsafe { candle::safetensors::MmapedFile::new(f) }) - .collect::<candle::Result<Vec<_>>>()?; - let weights = weights - .iter() - .map(|w| w.deserialize()) - .collect::<candle::Result<Vec<_>>>()?; - let vb = VarBuilder::from_safetensors(weights, DTYPE, &self.device); + let vb = unsafe { + VarBuilder::from_mmaped_safetensors(&self.weights_filename, DTYPE, &self.device)? + }; Ok(t5::T5EncoderModel::load(vb, &self.config)?) } pub fn build_conditional_generation(&self) -> Result<t5::T5ForConditionalGeneration> { - let weights = self - .weights_filename - .iter() - .map(|f| unsafe { candle::safetensors::MmapedFile::new(f) }) - .collect::<candle::Result<Vec<_>>>()?; - let weights = weights - .iter() - .map(|w| w.deserialize()) - .collect::<candle::Result<Vec<_>>>()?; - let vb = VarBuilder::from_safetensors(weights, DTYPE, &self.device); + let vb = unsafe { + VarBuilder::from_mmaped_safetensors(&self.weights_filename, DTYPE, &self.device)? + }; Ok(t5::T5ForConditionalGeneration::load(vb, &self.config)?) } } diff --git a/candle-nn/src/var_builder.rs b/candle-nn/src/var_builder.rs index 4ccbaf17..7b733e0c 100644 --- a/candle-nn/src/var_builder.rs +++ b/candle-nn/src/var_builder.rs @@ -325,6 +325,32 @@ impl SimpleBackend for candle::npy::NpzTensors { } } +impl SimpleBackend for candle::safetensors::MmapedSafetensors { + fn get( + &self, + s: Shape, + name: &str, + _: crate::Init, + dtype: DType, + dev: &Device, + ) -> Result<Tensor> { + let tensor = self.load(name, dev)?.to_dtype(dtype)?; + if tensor.shape() != &s { + Err(candle::Error::UnexpectedShape { + msg: format!("shape mismatch for {name}"), + expected: s, + got: tensor.shape().clone(), + } + .bt())? + } + Ok(tensor) + } + + fn contains_tensor(&self, name: &str) -> bool { + self.get(name).is_ok() + } +} + impl<'a> VarBuilder<'a> { fn new(backend: Box<dyn SimpleBackend + 'a>, dtype: DType, device: Device) -> Self { let data = TensorData { @@ -361,7 +387,7 @@ impl<'a> VarBuilder<'a> { } /// Initializes a `VarBuilder` that retrieves tensors stored in a collection of safetensors - /// files. + /// data. pub fn from_safetensors(safetensors: Vec<SafeTensors<'a>>, dtype: DType, dev: &Device) -> Self { let mut routing = HashMap::new(); for (index, sf) in safetensors.iter().enumerate() { @@ -376,6 +402,21 @@ impl<'a> VarBuilder<'a> { Self::new(Box::new(tensors), dtype, dev.clone()) } + /// Initializes a `VarBuilder` that retrieves tensors stored in a collection of safetensors + /// files. + /// + /// # Safety + /// + /// The unsafe is inherited from [`memmap2::MmapOptions`]. + pub unsafe fn from_mmaped_safetensors<P: AsRef<std::path::Path>>( + paths: &[P], + dtype: DType, + dev: &Device, + ) -> Result<Self> { + let tensors = candle::safetensors::MmapedSafetensors::multi(paths)?; + Ok(Self::new(Box::new(tensors), dtype, dev.clone())) + } + /// Initializes a `VarBuilder` that retrieves tensors stored in a numpy npz file. pub fn from_npz<P: AsRef<std::path::Path>>(p: P, dtype: DType, dev: &Device) -> Result<Self> { let npz = candle::npy::NpzTensors::new(p)?; |