//! Module to load `safetensor` files into CPU/GPU memory. //! //! There are multiple ways to load tensors from safetensor files: //! - `load` function for loading directly into memory and returning a HashMap of tensors //! - `MmapedSafetensors` for memory mapping files and avoiding full allocation //! - `SliceSafetensors` for working with in-memory buffers //! - `BufferedSafetensors` for owning a buffer of data //! //! Tensors can also be serialized to safetensor format using the `save` function or //! `Tensor::save_safetensors` method. //! use crate::{DType, Device, Error, Result, Tensor, WithDType}; use safetensors::tensor as st; use safetensors::tensor::SafeTensors; use std::borrow::Cow; use std::collections::HashMap; use std::path::Path; impl From for st::Dtype { fn from(value: DType) -> Self { match value { DType::U8 => st::Dtype::U8, DType::U32 => st::Dtype::U32, DType::I64 => st::Dtype::I64, DType::BF16 => st::Dtype::BF16, DType::F16 => st::Dtype::F16, DType::F32 => st::Dtype::F32, DType::F64 => st::Dtype::F64, } } } impl TryFrom for DType { type Error = Error; fn try_from(value: st::Dtype) -> Result { match value { st::Dtype::U8 => Ok(DType::U8), st::Dtype::U32 => Ok(DType::U32), st::Dtype::I64 => Ok(DType::I64), st::Dtype::BF16 => Ok(DType::BF16), st::Dtype::F16 => Ok(DType::F16), st::Dtype::F32 => Ok(DType::F32), st::Dtype::F64 => Ok(DType::F64), dtype => Err(Error::UnsupportedSafeTensorDtype(dtype)), } } } impl st::View for Tensor { fn dtype(&self) -> st::Dtype { self.dtype().into() } fn shape(&self) -> &[usize] { self.shape().dims() } fn data(&self) -> Cow<[u8]> { // This copies data from GPU to CPU. // TODO: Avoid the unwrap here. Cow::Owned(convert_back(self).unwrap()) } fn data_len(&self) -> usize { let n: usize = self.shape().elem_count(); let bytes_per_element = self.dtype().size_in_bytes(); n * bytes_per_element } } impl st::View for &Tensor { fn dtype(&self) -> st::Dtype { (*self).dtype().into() } fn shape(&self) -> &[usize] { self.dims() } fn data(&self) -> Cow<[u8]> { // This copies data from GPU to CPU. // TODO: Avoid the unwrap here. Cow::Owned(convert_back(self).unwrap()) } fn data_len(&self) -> usize { let n: usize = self.dims().iter().product(); let bytes_per_element = (*self).dtype().size_in_bytes(); n * bytes_per_element } } impl Tensor { pub fn save_safetensors>(&self, name: &str, filename: P) -> Result<()> { let data = [(name, self.clone())]; Ok(st::serialize_to_file(data, &None, filename.as_ref())?) } } fn convert_slice(data: &[u8], shape: &[usize], device: &Device) -> Result { let size_in_bytes = T::DTYPE.size_in_bytes(); let elem_count = data.len() / size_in_bytes; if (data.as_ptr() as usize) % size_in_bytes == 0 { // SAFETY This is safe because we just checked that this // was correctly aligned. let data: &[T] = unsafe { std::slice::from_raw_parts(data.as_ptr() as *const T, elem_count) }; Tensor::from_slice(data, shape, device) } else { // XXX: We need to specify `T` here, otherwise the compiler will infer u8 because of the following cast // Making this vector too small to fit a full f16/f32/f64 weights, resulting in out-of-bounds access let mut c: Vec = Vec::with_capacity(elem_count); // SAFETY: We just created c, so the allocated memory is necessarily // contiguous and non overlapping with the view's data. // We're downgrading the `c` pointer from T to u8, which removes alignment // constraints. unsafe { std::ptr::copy_nonoverlapping(data.as_ptr(), c.as_mut_ptr() as *mut u8, data.len()); c.set_len(elem_count) } Tensor::from_slice(&c, shape, device) } } fn convert_slice_with_cast Result>( data: &[u8], shape: &[usize], device: &Device, conv: F, ) -> Result { let size_in_bytes = std::mem::size_of::(); let elem_count = data.len() / size_in_bytes; if (data.as_ptr() as usize) % size_in_bytes == 0 { // SAFETY This is safe because we just checked that this // was correctly aligned. let data: &[T] = unsafe { std::slice::from_raw_parts(data.as_ptr() as *const T, elem_count) }; let data = data.iter().map(|t| conv(*t)).collect::>>()?; Tensor::from_vec(data, shape, device) } else { // XXX: We need to specify `T` here, otherwise the compiler will infer u8 because of the following cast // Making this vector too small to fit a full f16/f32/f64 weights, resulting in out-of-bounds access let mut c: Vec = Vec::with_capacity(elem_count); // SAFETY: We just created c, so the allocated memory is necessarily // contiguous and non overlapping with the view's data. // We're downgrading the `c` pointer from T to u8, which removes alignment // constraints. unsafe { std::ptr::copy_nonoverlapping(data.as_ptr(), c.as_mut_ptr() as *mut u8, data.len()); c.set_len(elem_count) } let c = c.into_iter().map(conv).collect::>>()?; Tensor::from_vec(c, shape, device) } } fn convert_with_cast_ Result>( view: &st::TensorView<'_>, device: &Device, conv: F, ) -> Result { convert_slice_with_cast::(view.data(), view.shape(), device, conv) } fn convert_(view: &st::TensorView<'_>, device: &Device) -> Result { convert_slice::(view.data(), view.shape(), device) } fn convert_back_(mut vs: Vec) -> Vec { let size_in_bytes = T::DTYPE.size_in_bytes(); let length = vs.len() * size_in_bytes; let capacity = vs.capacity() * size_in_bytes; let ptr = vs.as_mut_ptr() as *mut u8; // Don't run the destructor for Vec std::mem::forget(vs); // SAFETY: // // Every T is larger than u8, so there is no issue regarding alignment. // This re-interpret the Vec as a Vec. unsafe { Vec::from_raw_parts(ptr, length, capacity) } } pub trait Load { fn load(&self, device: &Device) -> Result; } impl Load for st::TensorView<'_> { fn load(&self, device: &Device) -> Result { convert(self, device) } } impl Tensor { pub fn from_raw_buffer( data: &[u8], dtype: DType, shape: &[usize], device: &Device, ) -> Result { match dtype { DType::U8 => convert_slice::(data, shape, device), DType::U32 => convert_slice::(data, shape, device), DType::I64 => convert_slice::(data, shape, device), DType::BF16 => convert_slice::(data, shape, device), DType::F16 => convert_slice::(data, shape, device), DType::F32 => convert_slice::(data, shape, device), DType::F64 => convert_slice::(data, shape, device), } } } fn convert(view: &st::TensorView<'_>, device: &Device) -> Result { match view.dtype() { st::Dtype::U8 => convert_::(view, device), st::Dtype::U16 => { let conv = |x| Ok(u32::from(x)); convert_with_cast_::(view, device, conv) } st::Dtype::U32 => convert_::(view, device), st::Dtype::I32 => { let conv = |x| Ok(i64::from(x)); convert_with_cast_::(view, device, conv) } st::Dtype::I64 => convert_::(view, device), st::Dtype::BF16 => convert_::(view, device), st::Dtype::F16 => convert_::(view, device), st::Dtype::F32 => convert_::(view, device), st::Dtype::F64 => convert_::(view, device), dtype => Err(Error::UnsupportedSafeTensorDtype(dtype)), } } fn convert_back(tensor: &Tensor) -> Result> { // TODO: This makes an unnecessary copy when the tensor is on the cpu. let tensor = tensor.flatten_all()?; match tensor.dtype() { DType::U8 => Ok(convert_back_::(tensor.to_vec1()?)), DType::U32 => Ok(convert_back_::(tensor.to_vec1()?)), DType::I64 => Ok(convert_back_::(tensor.to_vec1()?)), DType::F16 => Ok(convert_back_::(tensor.to_vec1()?)), DType::BF16 => Ok(convert_back_::(tensor.to_vec1()?)), DType::F32 => Ok(convert_back_::(tensor.to_vec1()?)), DType::F64 => Ok(convert_back_::(tensor.to_vec1()?)), } } pub fn load>(filename: P, device: &Device) -> Result> { let data = std::fs::read(filename.as_ref())?; load_buffer(&data[..], device) } pub fn load_buffer(data: &[u8], device: &Device) -> Result> { let st = safetensors::SafeTensors::deserialize(data)?; st.tensors() .into_iter() .map(|(name, view)| Ok((name, view.load(device)?))) .collect() } pub fn save + Ord + std::fmt::Display, P: AsRef>( tensors: &HashMap, filename: P, ) -> Result<()> { Ok(st::serialize_to_file(tensors, &None, filename.as_ref())?) } #[derive(yoke::Yokeable)] struct SafeTensors_<'a>(SafeTensors<'a>); pub struct MmapedSafetensors { safetensors: Vec, memmap2::Mmap>>, routing: Option>, } impl MmapedSafetensors { /// Creates a wrapper around a memory mapped file and deserialize the safetensors header. /// /// # Safety /// /// The unsafe is inherited from [`memmap2::MmapOptions`]. pub unsafe fn new>(p: P) -> Result { let p = p.as_ref(); let file = std::fs::File::open(p).map_err(|e| Error::from(e).with_path(p))?; let file = memmap2::MmapOptions::new() .map(&file) .map_err(|e| Error::from(e).with_path(p))?; let safetensors = yoke::Yoke::, memmap2::Mmap>::try_attach_to_cart( file, |data: &[u8]| { let st = safetensors::SafeTensors::deserialize(data) .map_err(|e| Error::from(e).with_path(p))?; Ok::<_, Error>(SafeTensors_(st)) }, )?; Ok(Self { safetensors: vec![safetensors], routing: None, }) } /// Creates a wrapper around multiple memory mapped file and deserialize the safetensors headers. /// /// If a tensor name appears in multiple files, the last entry is returned. /// /// # Safety /// /// The unsafe is inherited from [`memmap2::MmapOptions`]. pub unsafe fn multi>(paths: &[P]) -> Result { let mut routing = HashMap::new(); let mut safetensors = vec![]; for (index, p) in paths.iter().enumerate() { let p = p.as_ref(); let file = std::fs::File::open(p).map_err(|e| Error::from(e).with_path(p))?; let file = memmap2::MmapOptions::new() .map(&file) .map_err(|e| Error::from(e).with_path(p))?; let data = yoke::Yoke::, memmap2::Mmap>::try_attach_to_cart( file, |data: &[u8]| { let st = safetensors::SafeTensors::deserialize(data) .map_err(|e| Error::from(e).with_path(p))?; Ok::<_, Error>(SafeTensors_(st)) }, )?; for k in data.get().0.names() { routing.insert(k.to_string(), index); } safetensors.push(data) } Ok(Self { safetensors, routing: Some(routing), }) } pub fn load(&self, name: &str, dev: &Device) -> Result { self.get(name)?.load(dev) } pub fn tensors(&self) -> Vec<(String, st::TensorView<'_>)> { let mut tensors = vec![]; for safetensors in self.safetensors.iter() { tensors.push(safetensors.get().0.tensors()) } tensors.into_iter().flatten().collect() } pub fn get(&self, name: &str) -> Result> { let index = match &self.routing { None => 0, Some(routing) => { let index = routing.get(name).ok_or_else(|| { Error::CannotFindTensor { path: name.to_string(), } .bt() })?; *index } }; Ok(self.safetensors[index].get().0.tensor(name)?) } } pub struct SliceSafetensors<'a> { safetensors: SafeTensors<'a>, } impl<'a> SliceSafetensors<'a> { /// Creates a wrapper around a binary buffer and deserialize the safetensors header. pub fn new(buffer: &'a [u8]) -> Result { let safetensors = safetensors::SafeTensors::deserialize(buffer)?; Ok(Self { safetensors }) } pub fn load(&self, name: &str, dev: &Device) -> Result { self.safetensors.tensor(name)?.load(dev) } pub fn tensors(&self) -> Vec<(String, st::TensorView<'_>)> { self.safetensors.tensors() } pub fn get(&self, name: &str) -> Result> { Ok(self.safetensors.tensor(name)?) } } pub struct BufferedSafetensors { safetensors: yoke::Yoke, Vec>, } impl BufferedSafetensors { /// Creates a wrapper around a binary buffer and deserialize the safetensors header. pub fn new(buffer: Vec) -> Result { let safetensors = yoke::Yoke::, Vec>::try_attach_to_cart( buffer, |data: &[u8]| { let st = safetensors::SafeTensors::deserialize(data)?; Ok::<_, Error>(SafeTensors_(st)) }, )?; Ok(Self { safetensors }) } pub fn load(&self, name: &str, dev: &Device) -> Result { self.get(name)?.load(dev) } pub fn tensors(&self) -> Vec<(String, st::TensorView<'_>)> { self.safetensors.get().0.tensors() } pub fn get(&self, name: &str) -> Result> { Ok(self.safetensors.get().0.tensor(name)?) } } pub struct MmapedFile { path: std::path::PathBuf, inner: memmap2::Mmap, } impl MmapedFile { /// Creates a wrapper around a memory mapped file from which you can retrieve /// tensors using [`MmapedFile::deserialize`] /// /// # Safety /// /// The unsafe is inherited from [`memmap2::MmapOptions`]. pub unsafe fn new>(p: P) -> Result { let p = p.as_ref(); let file = std::fs::File::open(p).map_err(|e| Error::from(e).with_path(p))?; let inner = memmap2::MmapOptions::new() .map(&file) .map_err(|e| Error::from(e).with_path(p))?; Ok(Self { inner, path: p.to_path_buf(), }) } pub fn deserialize(&self) -> Result> { let st = safetensors::SafeTensors::deserialize(&self.inner) .map_err(|e| Error::from(e).with_path(&self.path))?; Ok(st) } } #[cfg(test)] mod tests { use super::*; use std::collections::HashMap; #[test] fn save_single_tensor() { let t = Tensor::zeros((2, 2), DType::F32, &Device::Cpu).unwrap(); t.save_safetensors("t", "t.safetensors").unwrap(); let bytes = std::fs::read("t.safetensors").unwrap(); assert_eq!(bytes, b"@\0\0\0\0\0\0\0{\"t\":{\"dtype\":\"F32\",\"shape\":[2,2],\"data_offsets\":[0,16]}} \0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0"); std::fs::remove_file("t.safetensors").unwrap(); } #[test] fn save_load_multiple_tensors() { let t = Tensor::zeros((2, 2), DType::F32, &Device::Cpu).unwrap(); let u = Tensor::zeros((1, 2), DType::F32, &Device::Cpu).unwrap(); let map: HashMap<_, _> = [("t", t), ("u", u)].into_iter().collect(); save(&map, "multi.safetensors").unwrap(); let weights = load("multi.safetensors", &Device::Cpu).unwrap(); assert_eq!(weights.get("t").unwrap().dims(), &[2, 2]); assert_eq!(weights.get("u").unwrap().dims(), &[1, 2]); let bytes = std::fs::read("multi.safetensors").unwrap(); assert_eq!(bytes, b"x\0\0\0\0\0\0\0{\"t\":{\"dtype\":\"F32\",\"shape\":[2,2],\"data_offsets\":[0,16]},\"u\":{\"dtype\":\"F32\",\"shape\":[1,2],\"data_offsets\":[16,24]}} \0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0"); std::fs::remove_file("multi.safetensors").unwrap(); } }