summaryrefslogtreecommitdiff
path: root/candle-nn
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2024-05-11 13:15:42 +0200
committerGitHub <noreply@github.com>2024-05-11 13:15:42 +0200
commit21f82a5155818214070b8b50e45e0148bd13833e (patch)
tree60b4e88684b744e3d4a5c4099607a4786d48f12c /candle-nn
parent9cff7bc3f48bb8836d0f12b22b00baacd7a5c9bb (diff)
downloadcandle-21f82a5155818214070b8b50e45e0148bd13833e.tar.gz
candle-21f82a5155818214070b8b50e45e0148bd13833e.tar.bz2
candle-21f82a5155818214070b8b50e45e0148bd13833e.zip
Add SliceSafetensors. (#2179)
* Add SlicedSafetensors. * And add some testing.
Diffstat (limited to 'candle-nn')
-rw-r--r--candle-nn/src/var_builder.rs6
1 files changed, 6 insertions, 0 deletions
diff --git a/candle-nn/src/var_builder.rs b/candle-nn/src/var_builder.rs
index 68bd6f05..ebbc9084 100644
--- a/candle-nn/src/var_builder.rs
+++ b/candle-nn/src/var_builder.rs
@@ -487,6 +487,12 @@ impl<'a> VarBuilder<'a> {
Ok(Self::from_backend(Box::new(tensors), dtype, dev.clone()))
}
+ /// Initializes a `VarBuilder` from a binary builder in the safetensor format.
+ pub fn from_slice_safetensors(data: Vec<u8>, dtype: DType, dev: &Device) -> Result<Self> {
+ let tensors = candle::safetensors::BufferedSafetensors::new(data)?;
+ Ok(Self::from_backend(Box::new(tensors), dtype, dev.clone()))
+ }
+
/// Initializes a `VarBuilder` that retrieves tensors stored in a numpy npz file.
pub fn from_npz<P: AsRef<std::path::Path>>(p: P, dtype: DType, dev: &Device) -> Result<Self> {
let npz = candle::npy::NpzTensors::new(p)?;