summaryrefslogtreecommitdiff
path: root/candle-core/examples/llama/weights.rs
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-07-03 10:19:57 +0100
committerGitHub <noreply@github.com>2023-07-03 10:19:57 +0100
commitec4871b8a41615546ea8b42638c56ce64a9fdf72 (patch)
treee1eb41588bd349d66e0d14f55a539a027899e3fa /candle-core/examples/llama/weights.rs
parentb036faf6a0ab163b926a0178061aee9f9cf8034f (diff)
parent899c76de7567572f522f8711e10150de8e4e0d6f (diff)
downloadcandle-ec4871b8a41615546ea8b42638c56ce64a9fdf72.tar.gz
candle-ec4871b8a41615546ea8b42638c56ce64a9fdf72.tar.bz2
candle-ec4871b8a41615546ea8b42638c56ce64a9fdf72.zip
Merge pull request #57 from LaurentMazare/safetensor-module2
Move more safetensors bits to the shared module.
Diffstat (limited to 'candle-core/examples/llama/weights.rs')
-rw-r--r--candle-core/examples/llama/weights.rs24
1 files changed, 8 insertions, 16 deletions
diff --git a/candle-core/examples/llama/weights.rs b/candle-core/examples/llama/weights.rs
index 4ad9b391..cc3fccd4 100644
--- a/candle-core/examples/llama/weights.rs
+++ b/candle-core/examples/llama/weights.rs
@@ -1,8 +1,5 @@
use super::*;
-use candle::{Device, Result, Tensor};
-use memmap2::MmapOptions;
-use safetensors::SafeTensors;
-use std::fs::File;
+use candle::{safetensors::SafeTensors, Device, Result, Tensor};
use std::path::PathBuf;
pub struct VarBuilder<'a> {
@@ -30,8 +27,9 @@ impl<'a> VarBuilder<'a> {
pub fn get(&self, tensor_name: &str) -> Result<Tensor> {
// Unwrap or 0 just to let the proper error flow.
let index = self.routing.get(tensor_name).unwrap_or(&0);
- let view = self.safetensors[*index].tensor(tensor_name).unwrap();
- candle::safetensors::convert(view, &self.device)?.to_dtype(DTYPE)
+ self.safetensors[*index]
+ .tensor(tensor_name, &self.device)?
+ .to_dtype(DTYPE)
}
}
@@ -107,18 +105,12 @@ impl Llama {
) -> Result<Self> {
let handles: Vec<_> = filenames
.iter()
- .map(|f| {
- let file = File::open(f).unwrap();
- unsafe { MmapOptions::new().map(&file).unwrap() }
- })
- .collect();
+ .map(candle::safetensors::MmapedFile::new)
+ .collect::<Result<Vec<_>>>()?;
let tensors: Vec<_> = handles
.iter()
- .map(|h| {
- let tensors = SafeTensors::deserialize(h).unwrap();
- tensors
- })
- .collect();
+ .map(|h| h.deserialize())
+ .collect::<Result<Vec<_>>>()?;
let vb = VarBuilder::new(tensors, device.clone());