summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--Cargo.toml5
-rw-r--r--candle-core/Cargo.toml1
-rw-r--r--candle-core/examples/tensor-tools.rs3
-rw-r--r--candle-core/src/safetensors.rs94
-rw-r--r--candle-nn/src/var_map.rs16
5 files changed, 104 insertions, 15 deletions
diff --git a/Cargo.toml b/Cargo.toml
index 5ae64523..6db2a326 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -41,9 +41,10 @@ imageproc = { version = "0.23.0", default-features = false }
intel-mkl-src = { version = "0.8.1", features = ["mkl-static-lp64-iomp"] }
libc = { version = "0.2.147" }
log = "0.4"
-memmap2 = "0.7.1"
+memmap2 = { version = "0.7.1", features = ["stable_deref_trait"] }
num_cpus = "1.15.0"
num-traits = "0.2.15"
+parquet = { version = "45.0.0" }
rand = "0.8.5"
rand_distr = "0.4.3"
rayon = "1.7.0"
@@ -57,8 +58,8 @@ tracing = "0.1.37"
tracing-chrome = "0.7.1"
tracing-subscriber = "0.3.7"
wav = "1.0.0"
+yoke = { version = "0.7.2", features = ["derive"] }
zip = { version = "0.6.6", default-features = false }
-parquet = { version = "45.0.0" }
[profile.release-with-debug]
inherits = "release"
diff --git a/candle-core/Cargo.toml b/candle-core/Cargo.toml
index 7af9b6fa..0b6fce50 100644
--- a/candle-core/Cargo.toml
+++ b/candle-core/Cargo.toml
@@ -26,6 +26,7 @@ rand_distr = { workspace = true }
rayon = { workspace = true }
safetensors = { workspace = true }
thiserror = { workspace = true }
+yoke = { workspace = true }
zip = { workspace = true }
[dev-dependencies]
diff --git a/candle-core/examples/tensor-tools.rs b/candle-core/examples/tensor-tools.rs
index c0d5a334..3982f2c3 100644
--- a/candle-core/examples/tensor-tools.rs
+++ b/candle-core/examples/tensor-tools.rs
@@ -150,8 +150,7 @@ fn run_ls(file: &std::path::PathBuf, format: Option<Format>, verbose: bool) -> R
}
}
Format::Safetensors => {
- let tensors = unsafe { candle_core::safetensors::MmapedFile::new(file)? };
- let tensors = tensors.deserialize()?;
+ let tensors = unsafe { candle_core::safetensors::MmapedSafetensors::new(file)? };
let mut tensors = tensors.tensors();
tensors.sort_by(|a, b| a.0.cmp(&b.0));
for (name, view) in tensors.iter() {
diff --git a/candle-core/src/safetensors.rs b/candle-core/src/safetensors.rs
index d588ea67..7e23c582 100644
--- a/candle-core/src/safetensors.rs
+++ b/candle-core/src/safetensors.rs
@@ -251,6 +251,100 @@ pub fn save<K: AsRef<str> + Ord + std::fmt::Display, P: AsRef<Path>>(
Ok(st::serialize_to_file(tensors, &None, filename.as_ref())?)
}
+#[derive(yoke::Yokeable)]
+struct SafeTensors_<'a>(SafeTensors<'a>);
+
+pub struct MmapedSafetensors {
+ safetensors: Vec<yoke::Yoke<SafeTensors_<'static>, memmap2::Mmap>>,
+ routing: Option<HashMap<String, usize>>,
+}
+
+impl MmapedSafetensors {
+ /// Creates a wrapper around a memory mapped file and deserialize the safetensors header.
+ ///
+ /// # Safety
+ ///
+ /// The unsafe is inherited from [`memmap2::MmapOptions`].
+ pub unsafe fn new<P: AsRef<Path>>(p: P) -> Result<Self> {
+ let p = p.as_ref();
+ let file = std::fs::File::open(p).map_err(|e| Error::from(e).with_path(p))?;
+ let file = memmap2::MmapOptions::new()
+ .map(&file)
+ .map_err(|e| Error::from(e).with_path(p))?;
+ let safetensors = yoke::Yoke::<SafeTensors_<'static>, memmap2::Mmap>::try_attach_to_cart(
+ file,
+ |data: &[u8]| {
+ let st = safetensors::SafeTensors::deserialize(data)
+ .map_err(|e| Error::from(e).with_path(p))?;
+ Ok::<_, Error>(SafeTensors_(st))
+ },
+ )?;
+ Ok(Self {
+ safetensors: vec![safetensors],
+ routing: None,
+ })
+ }
+
+ /// Creates a wrapper around multiple memory mapped file and deserialize the safetensors headers.
+ ///
+ /// If a tensor name appears in multiple files, the last entry is returned.
+ ///
+ /// # Safety
+ ///
+ /// The unsafe is inherited from [`memmap2::MmapOptions`].
+ pub unsafe fn multi<P: AsRef<Path>>(paths: &[P]) -> Result<Self> {
+ let mut routing = HashMap::new();
+ let mut safetensors = vec![];
+ for (index, p) in paths.iter().enumerate() {
+ let p = p.as_ref();
+ let file = std::fs::File::open(p).map_err(|e| Error::from(e).with_path(p))?;
+ let file = memmap2::MmapOptions::new()
+ .map(&file)
+ .map_err(|e| Error::from(e).with_path(p))?;
+ let data = yoke::Yoke::<SafeTensors_<'static>, memmap2::Mmap>::try_attach_to_cart(
+ file,
+ |data: &[u8]| {
+ let st = safetensors::SafeTensors::deserialize(data)
+ .map_err(|e| Error::from(e).with_path(p))?;
+ Ok::<_, Error>(SafeTensors_(st))
+ },
+ )?;
+ for k in data.get().0.names() {
+ routing.insert(k.to_string(), index);
+ }
+ safetensors.push(data)
+ }
+ Ok(Self {
+ safetensors,
+ routing: Some(routing),
+ })
+ }
+
+ pub fn load(&self, name: &str, dev: &Device) -> Result<Tensor> {
+ let index = match &self.routing {
+ None => 0,
+ Some(routing) => {
+ let index = routing.get(name).ok_or_else(|| {
+ Error::CannotFindTensor {
+ path: name.to_string(),
+ }
+ .bt()
+ })?;
+ *index
+ }
+ };
+ self.safetensors[index].get().0.tensor(name)?.load(dev)
+ }
+
+ pub fn tensors(&self) -> Vec<(String, st::TensorView<'_>)> {
+ let mut tensors = vec![];
+ for safetensors in self.safetensors.iter() {
+ tensors.push(safetensors.get().0.tensors())
+ }
+ tensors.into_iter().flatten().collect()
+ }
+}
+
pub struct MmapedFile {
path: std::path::PathBuf,
inner: memmap2::Mmap,
diff --git a/candle-nn/src/var_map.rs b/candle-nn/src/var_map.rs
index c17558b7..d3da84e8 100644
--- a/candle-nn/src/var_map.rs
+++ b/candle-nn/src/var_map.rs
@@ -1,4 +1,4 @@
-use candle::{safetensors::Load, DType, Device, Result, Shape, Tensor, Var};
+use candle::{DType, Device, Result, Shape, Tensor, Var};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
@@ -40,18 +40,12 @@ impl VarMap {
/// Note that values for variables that are currently not in the map are not kept.
pub fn load<P: AsRef<std::path::Path>>(&mut self, path: P) -> Result<()> {
let path = path.as_ref();
- let data = unsafe { candle::safetensors::MmapedFile::new(path)? };
- let data = data.deserialize()?;
+ let data = unsafe { candle::safetensors::MmapedSafetensors::new(path)? };
let mut tensor_data = self.data.lock().unwrap();
for (name, var) in tensor_data.iter_mut() {
- match data.tensor(name) {
- Ok(data) => {
- let data: Tensor = data.load(var.device())?;
- if let Err(err) = var.set(&data) {
- candle::bail!("error setting {name} using data from {path:?}: {err}",)
- }
- }
- Err(_) => candle::bail!("cannot find tensor for {name}"),
+ let data = data.load(name, var.device())?;
+ if let Err(err) = var.set(&data) {
+ candle::bail!("error setting {name} using data from {path:?}: {err}",)
}
}
Ok(())