summary | refs | log | tree | commit | diff
path: root/candle-nn/src
diff options
context:
space:
mode:
author: Laurent Mazare <laurent.mazare@gmail.com> 2023-07-11 20:22:34 +0100
committer: GitHub <noreply@github.com> 2023-07-11 20:22:34 +0100
commit: fa760759e5fa94c8486566af6dd3a456d0548221 (patch)
tree: 257cb5c9aab3f1dffcdf5049b0ae850312a12fd6 /candle-nn/src
parent: 37cad858698e519435c916421cc97b4f6b7fe53e (diff)
download: candle-fa760759e5fa94c8486566af6dd3a456d0548221.tar.gz
candle-fa760759e5fa94c8486566af6dd3a456d0548221.tar.bz2
candle-fa760759e5fa94c8486566af6dd3a456d0548221.zip
Allow for lazy loading of npz files, use it in llama to reduce memory usage in the cpu version. (#141)
Diffstat (limited to 'candle-nn/src')
-rw-r--r-- candle-nn/src/var_builder.rs 29
1 file changed, 27 insertions, 2 deletions
diff --git a/candle-nn/src/var_builder.rs b/candle-nn/src/var_builder.rs
index 6d79bddd..7f68ae08 100644
--- a/candle-nn/src/var_builder.rs
+++ b/candle-nn/src/var_builder.rs
@@ -1,4 +1,4 @@
-use candle::{safetensors::SafeTensors, DType, Device, Error, Shape, Tensor};
+use candle::{safetensors::SafeTensors, DType, Device, Error, Result, Shape, Tensor};
use std::collections::HashMap;
use std::sync::Arc;
@@ -9,6 +9,7 @@ enum Tensors<'a> {
routing: HashMap<String, usize>,
safetensors: Vec<SafeTensors<'a>>,
},
+ Npz(candle::npy::NpzTensors),
TensorMap(HashMap<String, Tensor>),
Zeros,
}
@@ -53,6 +54,15 @@ impl<'a> TensorData<'a> {
dtype,
}
}
+
+ fn from_npz<P: AsRef<std::path::Path>>(file: P, dtype: DType, device: &Device) -> Result<Self> {
+ let npz = candle::npy::NpzTensors::new(file)?;
+ Ok(Self {
+ tensors: Tensors::Npz(npz),
+ device: device.clone(),
+ dtype,
+ })
+ }
}
#[derive(Clone)]
@@ -88,6 +98,18 @@ impl<'a> VarBuilder<'a> {
}
}
+ pub fn from_npz<P: AsRef<std::path::Path>>(
+ file: P,
+ dtype: DType,
+ device: &Device,
+ ) -> Result<Self> {
+ let data = TensorData::from_npz(file, dtype, device)?;
+ Ok(Self {
+ data: Arc::new(data),
+ path: vec![],
+ })
+ }
+
pub fn push_prefix(&self, s: &str) -> Self {
let mut path = self.path.clone();
path.push(s.to_string());
@@ -112,7 +134,7 @@ impl<'a> VarBuilder<'a> {
}
impl<'a> VarBuilder<'a> {
- pub fn get<S: Into<Shape>>(&self, s: S, tensor_name: &str) -> candle::Result<Tensor> {
+ pub fn get<S: Into<Shape>>(&self, s: S, tensor_name: &str) -> Result<Tensor> {
let data = self.data.as_ref();
let s: Shape = s.into();
let path = if self.path.is_empty() {
@@ -128,6 +150,9 @@ impl<'a> VarBuilder<'a> {
path: path.to_string(),
})?
.clone(),
+ Tensors::Npz(npz) => npz.get(&path)?.ok_or_else(|| Error::CannotFindTensor {
+ path: path.to_string(),
+ })?,
Tensors::SafeTensorWithRouting {
routing,
safetensors,