diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-07-11 20:22:34 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-07-11 20:22:34 +0100 |
commit | fa760759e5fa94c8486566af6dd3a456d0548221 (patch) | |
tree | 257cb5c9aab3f1dffcdf5049b0ae850312a12fd6 /candle-nn/src | |
parent | 37cad858698e519435c916421cc97b4f6b7fe53e (diff) | |
download | candle-fa760759e5fa94c8486566af6dd3a456d0548221.tar.gz candle-fa760759e5fa94c8486566af6dd3a456d0548221.tar.bz2 candle-fa760759e5fa94c8486566af6dd3a456d0548221.zip |
Allow for lazy loading of npz files, use it in llama to reduce memory usage in the cpu version. (#141)
Diffstat (limited to 'candle-nn/src')
-rw-r--r-- | candle-nn/src/var_builder.rs | 29 |
1 files changed, 27 insertions, 2 deletions
diff --git a/candle-nn/src/var_builder.rs b/candle-nn/src/var_builder.rs index 6d79bddd..7f68ae08 100644 --- a/candle-nn/src/var_builder.rs +++ b/candle-nn/src/var_builder.rs @@ -1,4 +1,4 @@ -use candle::{safetensors::SafeTensors, DType, Device, Error, Shape, Tensor}; +use candle::{safetensors::SafeTensors, DType, Device, Error, Result, Shape, Tensor}; use std::collections::HashMap; use std::sync::Arc; @@ -9,6 +9,7 @@ enum Tensors<'a> { routing: HashMap<String, usize>, safetensors: Vec<SafeTensors<'a>>, }, + Npz(candle::npy::NpzTensors), TensorMap(HashMap<String, Tensor>), Zeros, } @@ -53,6 +54,15 @@ impl<'a> TensorData<'a> { dtype, } } + + fn from_npz<P: AsRef<std::path::Path>>(file: P, dtype: DType, device: &Device) -> Result<Self> { + let npz = candle::npy::NpzTensors::new(file)?; + Ok(Self { + tensors: Tensors::Npz(npz), + device: device.clone(), + dtype, + }) + } } #[derive(Clone)] @@ -88,6 +98,18 @@ impl<'a> VarBuilder<'a> { } } + pub fn from_npz<P: AsRef<std::path::Path>>( + file: P, + dtype: DType, + device: &Device, + ) -> Result<Self> { + let data = TensorData::from_npz(file, dtype, device)?; + Ok(Self { + data: Arc::new(data), + path: vec![], + }) + } + pub fn push_prefix(&self, s: &str) -> Self { let mut path = self.path.clone(); path.push(s.to_string()); @@ -112,7 +134,7 @@ impl<'a> VarBuilder<'a> { } impl<'a> VarBuilder<'a> { - pub fn get<S: Into<Shape>>(&self, s: S, tensor_name: &str) -> candle::Result<Tensor> { + pub fn get<S: Into<Shape>>(&self, s: S, tensor_name: &str) -> Result<Tensor> { let data = self.data.as_ref(); let s: Shape = s.into(); let path = if self.path.is_empty() { @@ -128,6 +150,9 @@ impl<'a> VarBuilder<'a> { path: path.to_string(), })? .clone(), + Tensors::Npz(npz) => npz.get(&path)?.ok_or_else(|| Error::CannotFindTensor { + path: path.to_string(), + })?, Tensors::SafeTensorWithRouting { routing, safetensors, |