diff options
Diffstat (limited to 'candle-core/examples/llama/weights.rs')
-rw-r--r-- | candle-core/examples/llama/weights.rs | 24 |
1 files changed, 8 insertions, 16 deletions
diff --git a/candle-core/examples/llama/weights.rs b/candle-core/examples/llama/weights.rs index 4ad9b391..cc3fccd4 100644 --- a/candle-core/examples/llama/weights.rs +++ b/candle-core/examples/llama/weights.rs @@ -1,8 +1,5 @@ use super::*; -use candle::{Device, Result, Tensor}; -use memmap2::MmapOptions; -use safetensors::SafeTensors; -use std::fs::File; +use candle::{safetensors::SafeTensors, Device, Result, Tensor}; use std::path::PathBuf; pub struct VarBuilder<'a> { @@ -30,8 +27,9 @@ impl<'a> VarBuilder<'a> { pub fn get(&self, tensor_name: &str) -> Result<Tensor> { // Unwrap or 0 just to let the proper error flow. let index = self.routing.get(tensor_name).unwrap_or(&0); - let view = self.safetensors[*index].tensor(tensor_name).unwrap(); - candle::safetensors::convert(view, &self.device)?.to_dtype(DTYPE) + self.safetensors[*index] + .tensor(tensor_name, &self.device)? + .to_dtype(DTYPE) } } @@ -107,18 +105,12 @@ impl Llama { ) -> Result<Self> { let handles: Vec<_> = filenames .iter() - .map(|f| { - let file = File::open(f).unwrap(); - unsafe { MmapOptions::new().map(&file).unwrap() } - }) - .collect(); + .map(candle::safetensors::MmapedFile::new) + .collect::<Result<Vec<_>>>()?; let tensors: Vec<_> = handles .iter() - .map(|h| { - let tensors = SafeTensors::deserialize(h).unwrap(); - tensors - }) - .collect(); + .map(|h| h.deserialize()) + .collect::<Result<Vec<_>>>()?; let vb = VarBuilder::new(tensors, device.clone()); |