summaryrefslogtreecommitdiff
path: root/candle-nn/src
diff options
context:
space:
mode:
authorHarry Stern <boustrophedon@users.noreply.github.com>2024-05-12 01:26:06 -0400
committerGitHub <noreply@github.com>2024-05-12 07:26:06 +0200
commit13c64f6828360a9cb9b58b4f817e4f3b8316388c (patch)
tree0fc7111817c8d73dd9ec16ce56029a51bd37c8f3 /candle-nn/src
parent21f82a5155818214070b8b50e45e0148bd13833e (diff)
downloadcandle-13c64f6828360a9cb9b58b4f817e4f3b8316388c.tar.gz
candle-13c64f6828360a9cb9b58b4f817e4f3b8316388c.tar.bz2
candle-13c64f6828360a9cb9b58b4f817e4f3b8316388c.zip
Fix VarBuilder::from_slice_safetensors (#2180)
Also implement SimpleBackend for SliceSafetensors Signed-off-by: Harry Stern <harry@harrystern.net>
Diffstat (limited to 'candle-nn/src')
-rw-r--r--candle-nn/src/var_builder.rs34
1 files changed, 30 insertions, 4 deletions
diff --git a/candle-nn/src/var_builder.rs b/candle-nn/src/var_builder.rs
index ebbc9084..d6f6214f 100644
--- a/candle-nn/src/var_builder.rs
+++ b/candle-nn/src/var_builder.rs
@@ -422,6 +422,32 @@ impl SimpleBackend for candle::safetensors::BufferedSafetensors {
}
}
+impl<'a> SimpleBackend for candle::safetensors::SliceSafetensors<'a> {
+ fn get(
+ &self,
+ s: Shape,
+ name: &str,
+ _: crate::Init,
+ dtype: DType,
+ dev: &Device,
+ ) -> Result<Tensor> {
+ let tensor = self.load(name, dev)?.to_dtype(dtype)?;
+ if tensor.shape() != &s {
+ Err(candle::Error::UnexpectedShape {
+ msg: format!("shape mismatch for {name}"),
+ expected: s,
+ got: tensor.shape().clone(),
+ }
+ .bt())?
+ }
+ Ok(tensor)
+ }
+
+ fn contains_tensor(&self, name: &str) -> bool {
+ self.get(name).is_ok()
+ }
+}
+
impl<'a> VarBuilder<'a> {
/// Initializes a `VarBuilder` using a custom backend.
///
@@ -481,15 +507,15 @@ impl<'a> VarBuilder<'a> {
Ok(Self::from_backend(Box::new(tensors), dtype, dev.clone()))
}
- /// Initializes a `VarBuilder` from a binary builder in the safetensor format.
+ /// Initializes a `VarBuilder` from a binary buffer in the safetensor format.
pub fn from_buffered_safetensors(data: Vec<u8>, dtype: DType, dev: &Device) -> Result<Self> {
let tensors = candle::safetensors::BufferedSafetensors::new(data)?;
Ok(Self::from_backend(Box::new(tensors), dtype, dev.clone()))
}
- /// Initializes a `VarBuilder` from a binary builder in the safetensor format.
- pub fn from_slice_safetensors(data: Vec<u8>, dtype: DType, dev: &Device) -> Result<Self> {
- let tensors = candle::safetensors::BufferedSafetensors::new(data)?;
+ /// Initializes a `VarBuilder` from a binary slice in the safetensor format.
+ pub fn from_slice_safetensors(data: &'a [u8], dtype: DType, dev: &Device) -> Result<Self> {
+ let tensors = candle::safetensors::SliceSafetensors::new(data)?;
Ok(Self::from_backend(Box::new(tensors), dtype, dev.clone()))
}