diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-05-11 13:15:42 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-05-11 13:15:42 +0200 |
commit | 21f82a5155818214070b8b50e45e0148bd13833e (patch) | |
tree | 60b4e88684b744e3d4a5c4099607a4786d48f12c /candle-nn | |
parent | 9cff7bc3f48bb8836d0f12b22b00baacd7a5c9bb (diff) | |
download | candle-21f82a5155818214070b8b50e45e0148bd13833e.tar.gz candle-21f82a5155818214070b8b50e45e0148bd13833e.tar.bz2 candle-21f82a5155818214070b8b50e45e0148bd13833e.zip |
Add SliceSafetensors. (#2179)
* Add SlicedSafetensors.
* And add some testing.
Diffstat (limited to 'candle-nn')
-rw-r--r-- | candle-nn/src/var_builder.rs | 6 |
1 files changed, 6 insertions, 0 deletions
diff --git a/candle-nn/src/var_builder.rs b/candle-nn/src/var_builder.rs index 68bd6f05..ebbc9084 100644 --- a/candle-nn/src/var_builder.rs +++ b/candle-nn/src/var_builder.rs @@ -487,6 +487,12 @@ impl<'a> VarBuilder<'a> { Ok(Self::from_backend(Box::new(tensors), dtype, dev.clone())) } + /// Initializes a `VarBuilder` from a binary builder in the safetensor format. + pub fn from_slice_safetensors(data: Vec<u8>, dtype: DType, dev: &Device) -> Result<Self> { + let tensors = candle::safetensors::BufferedSafetensors::new(data)?; + Ok(Self::from_backend(Box::new(tensors), dtype, dev.clone())) + } + /// Initializes a `VarBuilder` that retrieves tensors stored in a numpy npz file. pub fn from_npz<P: AsRef<std::path::Path>>(p: P, dtype: DType, dev: &Device) -> Result<Self> { let npz = candle::npy::NpzTensors::new(p)?; |