summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
-rw-r--r-- candle-core/src/safetensors.rs | 22
-rw-r--r-- candle-examples/examples/t5/main.rs | 26
-rw-r--r-- candle-nn/src/var_builder.rs | 43
3 files changed, 61 insertions, 30 deletions
diff --git a/candle-core/src/safetensors.rs b/candle-core/src/safetensors.rs
index 7e23c582..12df7fbe 100644
--- a/candle-core/src/safetensors.rs
+++ b/candle-core/src/safetensors.rs
@@ -321,6 +321,18 @@ impl MmapedSafetensors {
}
pub fn load(&self, name: &str, dev: &Device) -> Result<Tensor> {
+ self.get(name)?.load(dev)
+ }
+
+ pub fn tensors(&self) -> Vec<(String, st::TensorView<'_>)> {
+ let mut tensors = vec![];
+ for safetensors in self.safetensors.iter() {
+ tensors.push(safetensors.get().0.tensors())
+ }
+ tensors.into_iter().flatten().collect()
+ }
+
+ pub fn get(&self, name: &str) -> Result<st::TensorView<'_>> {
let index = match &self.routing {
None => 0,
Some(routing) => {
@@ -333,15 +345,7 @@ impl MmapedSafetensors {
*index
}
};
- self.safetensors[index].get().0.tensor(name)?.load(dev)
- }
-
- pub fn tensors(&self) -> Vec<(String, st::TensorView<'_>)> {
- let mut tensors = vec![];
- for safetensors in self.safetensors.iter() {
- tensors.push(safetensors.get().0.tensors())
- }
- tensors.into_iter().flatten().collect()
+ Ok(self.safetensors[index].get().0.tensor(name)?)
}
}
diff --git a/candle-examples/examples/t5/main.rs b/candle-examples/examples/t5/main.rs
index 55929c33..71106497 100644
--- a/candle-examples/examples/t5/main.rs
+++ b/candle-examples/examples/t5/main.rs
@@ -122,30 +122,16 @@ impl T5ModelBuilder {
}
pub fn build_encoder(&self) -> Result<t5::T5EncoderModel> {
- let weights = self
- .weights_filename
- .iter()
- .map(|f| unsafe { candle::safetensors::MmapedFile::new(f) })
- .collect::<candle::Result<Vec<_>>>()?;
- let weights = weights
- .iter()
- .map(|w| w.deserialize())
- .collect::<candle::Result<Vec<_>>>()?;
- let vb = VarBuilder::from_safetensors(weights, DTYPE, &self.device);
+ let vb = unsafe {
+ VarBuilder::from_mmaped_safetensors(&self.weights_filename, DTYPE, &self.device)?
+ };
Ok(t5::T5EncoderModel::load(vb, &self.config)?)
}
pub fn build_conditional_generation(&self) -> Result<t5::T5ForConditionalGeneration> {
- let weights = self
- .weights_filename
- .iter()
- .map(|f| unsafe { candle::safetensors::MmapedFile::new(f) })
- .collect::<candle::Result<Vec<_>>>()?;
- let weights = weights
- .iter()
- .map(|w| w.deserialize())
- .collect::<candle::Result<Vec<_>>>()?;
- let vb = VarBuilder::from_safetensors(weights, DTYPE, &self.device);
+ let vb = unsafe {
+ VarBuilder::from_mmaped_safetensors(&self.weights_filename, DTYPE, &self.device)?
+ };
Ok(t5::T5ForConditionalGeneration::load(vb, &self.config)?)
}
}
diff --git a/candle-nn/src/var_builder.rs b/candle-nn/src/var_builder.rs
index 4ccbaf17..7b733e0c 100644
--- a/candle-nn/src/var_builder.rs
+++ b/candle-nn/src/var_builder.rs
@@ -325,6 +325,32 @@ impl SimpleBackend for candle::npy::NpzTensors {
}
}
+impl SimpleBackend for candle::safetensors::MmapedSafetensors {
+ fn get(
+ &self,
+ s: Shape,
+ name: &str,
+ _: crate::Init,
+ dtype: DType,
+ dev: &Device,
+ ) -> Result<Tensor> {
+ let tensor = self.load(name, dev)?.to_dtype(dtype)?;
+ if tensor.shape() != &s {
+ Err(candle::Error::UnexpectedShape {
+ msg: format!("shape mismatch for {name}"),
+ expected: s,
+ got: tensor.shape().clone(),
+ }
+ .bt())?
+ }
+ Ok(tensor)
+ }
+
+ fn contains_tensor(&self, name: &str) -> bool {
+ self.get(name).is_ok()
+ }
+}
+
impl<'a> VarBuilder<'a> {
fn new(backend: Box<dyn SimpleBackend + 'a>, dtype: DType, device: Device) -> Self {
let data = TensorData {
@@ -361,7 +387,7 @@ impl<'a> VarBuilder<'a> {
}
/// Initializes a `VarBuilder` that retrieves tensors stored in a collection of safetensors
- /// files.
+ /// data.
pub fn from_safetensors(safetensors: Vec<SafeTensors<'a>>, dtype: DType, dev: &Device) -> Self {
let mut routing = HashMap::new();
for (index, sf) in safetensors.iter().enumerate() {
@@ -376,6 +402,21 @@ impl<'a> VarBuilder<'a> {
Self::new(Box::new(tensors), dtype, dev.clone())
}
+ /// Initializes a `VarBuilder` that retrieves tensors stored in a collection of safetensors
+ /// files.
+ ///
+ /// # Safety
+ ///
+ /// The unsafe is inherited from [`memmap2::MmapOptions`].
+ pub unsafe fn from_mmaped_safetensors<P: AsRef<std::path::Path>>(
+ paths: &[P],
+ dtype: DType,
+ dev: &Device,
+ ) -> Result<Self> {
+ let tensors = candle::safetensors::MmapedSafetensors::multi(paths)?;
+ Ok(Self::new(Box::new(tensors), dtype, dev.clone()))
+ }
+
/// Initializes a `VarBuilder` that retrieves tensors stored in a numpy npz file.
pub fn from_npz<P: AsRef<std::path::Path>>(p: P, dtype: DType, dev: &Device) -> Result<Self> {
let npz = candle::npy::NpzTensors::new(p)?;