diff options
-rw-r--r-- | Cargo.toml | 5 | ||||
-rw-r--r-- | candle-core/Cargo.toml | 1 | ||||
-rw-r--r-- | candle-core/examples/tensor-tools.rs | 3 | ||||
-rw-r--r-- | candle-core/src/safetensors.rs | 94 | ||||
-rw-r--r-- | candle-nn/src/var_map.rs | 16 |
5 files changed, 104 insertions, 15 deletions
@@ -41,9 +41,10 @@ imageproc = { version = "0.23.0", default-features = false } intel-mkl-src = { version = "0.8.1", features = ["mkl-static-lp64-iomp"] } libc = { version = "0.2.147" } log = "0.4" -memmap2 = "0.7.1" +memmap2 = { version = "0.7.1", features = ["stable_deref_trait"] } num_cpus = "1.15.0" num-traits = "0.2.15" +parquet = { version = "45.0.0" } rand = "0.8.5" rand_distr = "0.4.3" rayon = "1.7.0" @@ -57,8 +58,8 @@ tracing = "0.1.37" tracing-chrome = "0.7.1" tracing-subscriber = "0.3.7" wav = "1.0.0" +yoke = { version = "0.7.2", features = ["derive"] } zip = { version = "0.6.6", default-features = false } -parquet = { version = "45.0.0" } [profile.release-with-debug] inherits = "release" diff --git a/candle-core/Cargo.toml b/candle-core/Cargo.toml index 7af9b6fa..0b6fce50 100644 --- a/candle-core/Cargo.toml +++ b/candle-core/Cargo.toml @@ -26,6 +26,7 @@ rand_distr = { workspace = true } rayon = { workspace = true } safetensors = { workspace = true } thiserror = { workspace = true } +yoke = { workspace = true } zip = { workspace = true } [dev-dependencies] diff --git a/candle-core/examples/tensor-tools.rs b/candle-core/examples/tensor-tools.rs index c0d5a334..3982f2c3 100644 --- a/candle-core/examples/tensor-tools.rs +++ b/candle-core/examples/tensor-tools.rs @@ -150,8 +150,7 @@ fn run_ls(file: &std::path::PathBuf, format: Option<Format>, verbose: bool) -> R } } Format::Safetensors => { - let tensors = unsafe { candle_core::safetensors::MmapedFile::new(file)? }; - let tensors = tensors.deserialize()?; + let tensors = unsafe { candle_core::safetensors::MmapedSafetensors::new(file)? }; let mut tensors = tensors.tensors(); tensors.sort_by(|a, b| a.0.cmp(&b.0)); for (name, view) in tensors.iter() { diff --git a/candle-core/src/safetensors.rs b/candle-core/src/safetensors.rs index d588ea67..7e23c582 100644 --- a/candle-core/src/safetensors.rs +++ b/candle-core/src/safetensors.rs @@ -251,6 +251,100 @@ pub fn save<K: AsRef<str> + Ord + std::fmt::Display, P: AsRef<Path>>( Ok(st::serialize_to_file(tensors, &None, filename.as_ref())?) } +#[derive(yoke::Yokeable)] +struct SafeTensors_<'a>(SafeTensors<'a>); + +pub struct MmapedSafetensors { + safetensors: Vec<yoke::Yoke<SafeTensors_<'static>, memmap2::Mmap>>, + routing: Option<HashMap<String, usize>>, +} + +impl MmapedSafetensors { + /// Creates a wrapper around a memory mapped file and deserialize the safetensors header. + /// + /// # Safety + /// + /// The unsafe is inherited from [`memmap2::MmapOptions`]. + pub unsafe fn new<P: AsRef<Path>>(p: P) -> Result<Self> { + let p = p.as_ref(); + let file = std::fs::File::open(p).map_err(|e| Error::from(e).with_path(p))?; + let file = memmap2::MmapOptions::new() + .map(&file) + .map_err(|e| Error::from(e).with_path(p))?; + let safetensors = yoke::Yoke::<SafeTensors_<'static>, memmap2::Mmap>::try_attach_to_cart( + file, + |data: &[u8]| { + let st = safetensors::SafeTensors::deserialize(data) + .map_err(|e| Error::from(e).with_path(p))?; + Ok::<_, Error>(SafeTensors_(st)) + }, + )?; + Ok(Self { + safetensors: vec![safetensors], + routing: None, + }) + } + + /// Creates a wrapper around multiple memory mapped file and deserialize the safetensors headers. + /// + /// If a tensor name appears in multiple files, the last entry is returned. + /// + /// # Safety + /// + /// The unsafe is inherited from [`memmap2::MmapOptions`]. + pub unsafe fn multi<P: AsRef<Path>>(paths: &[P]) -> Result<Self> { + let mut routing = HashMap::new(); + let mut safetensors = vec![]; + for (index, p) in paths.iter().enumerate() { + let p = p.as_ref(); + let file = std::fs::File::open(p).map_err(|e| Error::from(e).with_path(p))?; + let file = memmap2::MmapOptions::new() + .map(&file) + .map_err(|e| Error::from(e).with_path(p))?; + let data = yoke::Yoke::<SafeTensors_<'static>, memmap2::Mmap>::try_attach_to_cart( + file, + |data: &[u8]| { + let st = safetensors::SafeTensors::deserialize(data) + .map_err(|e| Error::from(e).with_path(p))?; + Ok::<_, Error>(SafeTensors_(st)) + }, + )?; + for k in data.get().0.names() { + routing.insert(k.to_string(), index); + } + safetensors.push(data) + } + Ok(Self { + safetensors, + routing: Some(routing), + }) + } + + pub fn load(&self, name: &str, dev: &Device) -> Result<Tensor> { + let index = match &self.routing { + None => 0, + Some(routing) => { + let index = routing.get(name).ok_or_else(|| { + Error::CannotFindTensor { + path: name.to_string(), + } + .bt() + })?; + *index + } + }; + self.safetensors[index].get().0.tensor(name)?.load(dev) + } + + pub fn tensors(&self) -> Vec<(String, st::TensorView<'_>)> { + let mut tensors = vec![]; + for safetensors in self.safetensors.iter() { + tensors.push(safetensors.get().0.tensors()) + } + tensors.into_iter().flatten().collect() + } +} + pub struct MmapedFile { path: std::path::PathBuf, inner: memmap2::Mmap, diff --git a/candle-nn/src/var_map.rs b/candle-nn/src/var_map.rs index c17558b7..d3da84e8 100644 --- a/candle-nn/src/var_map.rs +++ b/candle-nn/src/var_map.rs @@ -1,4 +1,4 @@ -use candle::{safetensors::Load, DType, Device, Result, Shape, Tensor, Var}; +use candle::{DType, Device, Result, Shape, Tensor, Var}; use std::collections::HashMap; use std::sync::{Arc, Mutex}; @@ -40,18 +40,12 @@ impl VarMap { /// Note that values for variables that are currently not in the map are not kept. pub fn load<P: AsRef<std::path::Path>>(&mut self, path: P) -> Result<()> { let path = path.as_ref(); - let data = unsafe { candle::safetensors::MmapedFile::new(path)? }; - let data = data.deserialize()?; + let data = unsafe { candle::safetensors::MmapedSafetensors::new(path)? }; let mut tensor_data = self.data.lock().unwrap(); for (name, var) in tensor_data.iter_mut() { - match data.tensor(name) { - Ok(data) => { - let data: Tensor = data.load(var.device())?; - if let Err(err) = var.set(&data) { - candle::bail!("error setting {name} using data from {path:?}: {err}",) - } - } - Err(_) => candle::bail!("cannot find tensor for {name}"), + let data = data.load(name, var.device())?; + if let Err(err) = var.set(&data) { + candle::bail!("error setting {name} using data from {path:?}: {err}",) } } Ok(()) |