diff options
Diffstat (limited to 'candle-core')
-rw-r--r-- | candle-core/src/safetensors.rs | 6 |
1 files changed, 5 insertions, 1 deletions
diff --git a/candle-core/src/safetensors.rs b/candle-core/src/safetensors.rs index 1880a041..132fb914 100644 --- a/candle-core/src/safetensors.rs +++ b/candle-core/src/safetensors.rs @@ -242,7 +242,11 @@ fn convert_back(tensor: &Tensor) -> Result<Vec<u8>> { pub fn load<P: AsRef<Path>>(filename: P, device: &Device) -> Result<HashMap<String, Tensor>> { let data = std::fs::read(filename.as_ref())?; - let st = safetensors::SafeTensors::deserialize(&data)?; + load_buffer(&data[..], device) +} + +pub fn load_buffer(data: &[u8], device: &Device) -> Result<HashMap<String, Tensor>> { + let st = safetensors::SafeTensors::deserialize(data)?; st.tensors() .into_iter() .map(|(name, view)| Ok((name, view.load(device)?))) |