diff options
author | Harry Stern <boustrophedon@users.noreply.github.com> | 2024-05-12 01:26:06 -0400 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-05-12 07:26:06 +0200 |
commit | 13c64f6828360a9cb9b58b4f817e4f3b8316388c (patch) | |
tree | 0fc7111817c8d73dd9ec16ce56029a51bd37c8f3 /candle-nn/src | |
parent | 21f82a5155818214070b8b50e45e0148bd13833e (diff) | |
download | candle-13c64f6828360a9cb9b58b4f817e4f3b8316388c.tar.gz candle-13c64f6828360a9cb9b58b4f817e4f3b8316388c.tar.bz2 candle-13c64f6828360a9cb9b58b4f817e4f3b8316388c.zip |
Fix VarBuilder::from_slice_safetensors (#2180)
Also implement SimpleBackend for SliceSafetensors
Signed-off-by: Harry Stern <harry@harrystern.net>
Diffstat (limited to 'candle-nn/src')
-rw-r--r-- | candle-nn/src/var_builder.rs | 34 |
1 files changed, 30 insertions, 4 deletions
diff --git a/candle-nn/src/var_builder.rs b/candle-nn/src/var_builder.rs index ebbc9084..d6f6214f 100644 --- a/candle-nn/src/var_builder.rs +++ b/candle-nn/src/var_builder.rs @@ -422,6 +422,32 @@ impl SimpleBackend for candle::safetensors::BufferedSafetensors { } } +impl<'a> SimpleBackend for candle::safetensors::SliceSafetensors<'a> { + fn get( + &self, + s: Shape, + name: &str, + _: crate::Init, + dtype: DType, + dev: &Device, + ) -> Result<Tensor> { + let tensor = self.load(name, dev)?.to_dtype(dtype)?; + if tensor.shape() != &s { + Err(candle::Error::UnexpectedShape { + msg: format!("shape mismatch for {name}"), + expected: s, + got: tensor.shape().clone(), + } + .bt())? + } + Ok(tensor) + } + + fn contains_tensor(&self, name: &str) -> bool { + self.get(name).is_ok() + } +} + impl<'a> VarBuilder<'a> { /// Initializes a `VarBuilder` using a custom backend. /// @@ -481,15 +507,15 @@ impl<'a> VarBuilder<'a> { Ok(Self::from_backend(Box::new(tensors), dtype, dev.clone())) } - /// Initializes a `VarBuilder` from a binary builder in the safetensor format. + /// Initializes a `VarBuilder` from a binary buffer in the safetensor format. pub fn from_buffered_safetensors(data: Vec<u8>, dtype: DType, dev: &Device) -> Result<Self> { let tensors = candle::safetensors::BufferedSafetensors::new(data)?; Ok(Self::from_backend(Box::new(tensors), dtype, dev.clone())) } - /// Initializes a `VarBuilder` from a binary builder in the safetensor format. - pub fn from_slice_safetensors(data: Vec<u8>, dtype: DType, dev: &Device) -> Result<Self> { - let tensors = candle::safetensors::BufferedSafetensors::new(data)?; + /// Initializes a `VarBuilder` from a binary slice in the safetensor format. + pub fn from_slice_safetensors(data: &'a [u8], dtype: DType, dev: &Device) -> Result<Self> { + let tensors = candle::safetensors::SliceSafetensors::new(data)?; Ok(Self::from_backend(Box::new(tensors), dtype, dev.clone())) } |