From 21f82a5155818214070b8b50e45e0148bd13833e Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Sat, 11 May 2024 13:15:42 +0200 Subject: Add SliceSafetensors. (#2179) * Add SlicedSafetensors. * And add some testing. --- candle-nn/src/var_builder.rs | 6 ++++++ 1 file changed, 6 insertions(+) (limited to 'candle-nn') diff --git a/candle-nn/src/var_builder.rs b/candle-nn/src/var_builder.rs index 68bd6f05..ebbc9084 100644 --- a/candle-nn/src/var_builder.rs +++ b/candle-nn/src/var_builder.rs @@ -487,6 +487,12 @@ impl<'a> VarBuilder<'a> { Ok(Self::from_backend(Box::new(tensors), dtype, dev.clone())) } + /// Initializes a `VarBuilder` from a binary builder in the safetensor format. + pub fn from_slice_safetensors(data: Vec, dtype: DType, dev: &Device) -> Result { + let tensors = candle::safetensors::BufferedSafetensors::new(data)?; + Ok(Self::from_backend(Box::new(tensors), dtype, dev.clone())) + } + /// Initializes a `VarBuilder` that retrieves tensors stored in a numpy npz file. pub fn from_npz>(p: P, dtype: DType, dev: &Device) -> Result { let npz = candle::npy::NpzTensors::new(p)?; -- cgit v1.2.3