summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-core/Cargo.toml14
-rw-r--r--candle-core/examples/llama/weights.rs24
-rw-r--r--candle-core/src/error.rs3
-rw-r--r--candle-core/src/safetensors.rs91
4 files changed, 88 insertions, 44 deletions
diff --git a/candle-core/Cargo.toml b/candle-core/Cargo.toml
index 1b7ef4c4..7076e4e4 100644
--- a/candle-core/Cargo.toml
+++ b/candle-core/Cargo.toml
@@ -11,25 +11,25 @@ license = "MIT/Apache-2.0"
readme = "README.md"
[dependencies]
-safetensors = "0.3.1"
-thiserror = "1"
-cudarc = { version = "0.9.9", optional = true, features = ["f16"] }
+byteorder = "1.4.3"
candle-kernels = { path = "../candle-kernels", optional = true }
+cudarc = { version = "0.9.9", optional = true, features = ["f16"] }
gemm = "0.15.4"
-zip = { version = "0.6.6", default-features=false }
-byteorder = "1.4.3"
half = { version = "2.3.1", features = ["num-traits"] }
+memmap2 = "0.7.1"
num-traits = "0.2.15"
num_cpus = "1.15.0"
+safetensors = "0.3.1"
+thiserror = "1"
+zip = { version = "0.6.6", default-features=false }
[dev-dependencies]
anyhow = { version = "1", features = ["backtrace"] }
+candle-hub = { path = "../candle-hub" }
clap = { version = "4.2.4", features = ["derive"] }
rand = "0.8.5"
tokenizers = { version = "0.13.3", default-features=false, features=["onig"] }
tokio = { version = "1.28.2", features = ["macros", "rt-multi-thread"] }
-candle-hub = { path = "../candle-hub" }
-memmap2 = "0.7.1"
[features]
default = ["cuda"]
diff --git a/candle-core/examples/llama/weights.rs b/candle-core/examples/llama/weights.rs
index 4ad9b391..cc3fccd4 100644
--- a/candle-core/examples/llama/weights.rs
+++ b/candle-core/examples/llama/weights.rs
@@ -1,8 +1,5 @@
use super::*;
-use candle::{Device, Result, Tensor};
-use memmap2::MmapOptions;
-use safetensors::SafeTensors;
-use std::fs::File;
+use candle::{safetensors::SafeTensors, Device, Result, Tensor};
use std::path::PathBuf;
pub struct VarBuilder<'a> {
@@ -30,8 +27,9 @@ impl<'a> VarBuilder<'a> {
pub fn get(&self, tensor_name: &str) -> Result<Tensor> {
// Unwrap or 0 just to let the proper error flow.
let index = self.routing.get(tensor_name).unwrap_or(&0);
- let view = self.safetensors[*index].tensor(tensor_name).unwrap();
- candle::safetensors::convert(view, &self.device)?.to_dtype(DTYPE)
+ self.safetensors[*index]
+ .tensor(tensor_name, &self.device)?
+ .to_dtype(DTYPE)
}
}
@@ -107,18 +105,12 @@ impl Llama {
) -> Result<Self> {
let handles: Vec<_> = filenames
.iter()
- .map(|f| {
- let file = File::open(f).unwrap();
- unsafe { MmapOptions::new().map(&file).unwrap() }
- })
- .collect();
+ .map(candle::safetensors::MmapedFile::new)
+ .collect::<Result<Vec<_>>>()?;
let tensors: Vec<_> = handles
.iter()
- .map(|h| {
- let tensors = SafeTensors::deserialize(h).unwrap();
- tensors
- })
- .collect();
+ .map(|h| h.deserialize())
+ .collect::<Result<Vec<_>>>()?;
let vb = VarBuilder::new(tensors, device.clone());
diff --git a/candle-core/src/error.rs b/candle-core/src/error.rs
index 71fd21de..d5de4296 100644
--- a/candle-core/src/error.rs
+++ b/candle-core/src/error.rs
@@ -114,6 +114,9 @@ pub enum Error {
#[error(transparent)]
SafeTensor(#[from] safetensors::SafeTensorError),
+ #[error("unsupported safetensor dtype {0:?}")]
+ UnsupportedSafeTensorDtype(safetensors::Dtype),
+
#[error("cannot broadcast {src_shape:?} to {dst_shape:?}")]
BroadcastIncompatibleShapes { src_shape: Shape, dst_shape: Shape },
}
diff --git a/candle-core/src/safetensors.rs b/candle-core/src/safetensors.rs
index 3ed36b64..b80a756a 100644
--- a/candle-core/src/safetensors.rs
+++ b/candle-core/src/safetensors.rs
@@ -1,27 +1,76 @@
-use crate::{Device, Result, Tensor};
-use half::f16;
+use crate::{Device, Error, Result, Tensor, WithDType};
use safetensors::tensor as st;
+fn convert_<T: WithDType>(view: st::TensorView<'_>, device: &Device) -> Result<Tensor> {
+ let v = view.data();
+ let size_in_bytes = T::DTYPE.size_in_bytes();
+ let elem_count = v.len() / size_in_bytes;
+ if (v.as_ptr() as usize) % size_in_bytes == 0 {
+ // SAFETY This is safe because we just checked that this
+ // was correctly aligned.
+ let data: &[T] = unsafe { std::slice::from_raw_parts(v.as_ptr() as *const T, elem_count) };
+ Tensor::from_slice(data, view.shape(), device)
+ } else {
+ let mut c = Vec::with_capacity(elem_count);
+ unsafe {
+ std::ptr::copy_nonoverlapping(v.as_ptr(), c.as_mut_ptr() as *mut u8, v.len());
+ c.set_len(elem_count)
+ }
+ Tensor::from_slice(&c, view.shape(), device)
+ }
+}
+
pub fn convert(view: st::TensorView<'_>, device: &Device) -> Result<Tensor> {
match view.dtype() {
- st::Dtype::F16 => {
- let v = view.data();
- if (v.as_ptr() as usize) % 2 == 0 {
- // SAFETY This is safe because we just checked that this
- // was correctly aligned.
- let data: &[f16] =
- unsafe { std::slice::from_raw_parts(v.as_ptr() as *const f16, v.len() / 2) };
- Tensor::from_slice(data, view.shape(), device)
- } else {
- let mut c = Vec::with_capacity(v.len() / 2);
- let mut i = 0;
- while i < v.len() {
- c.push(f16::from_le_bytes([v[i], v[i + 1]]));
- i += 2;
- }
- Tensor::from_slice(&c, view.shape(), device)
- }
- }
- dt => todo!("Unhandled dtype {dt:?}"),
+ st::Dtype::U8 => convert_::<u8>(view, device),
+ st::Dtype::U32 => convert_::<u8>(view, device),
+ st::Dtype::BF16 => convert_::<half::bf16>(view, device),
+ st::Dtype::F16 => convert_::<half::f16>(view, device),
+ st::Dtype::F32 => convert_::<f32>(view, device),
+ st::Dtype::F64 => convert_::<f64>(view, device),
+ dtype => Err(Error::UnsupportedSafeTensorDtype(dtype)),
+ }
+}
+
+// If Rust allowed for self-referential struct, we could store both the Mmap buffer and the
+// SafeTensor bits in the same struct and avoid having the final users calling two methods.
+// We could try using the ouroboros crate or equivalent for this at some point.
+// Wrap the SafeTensors main module so as to provide accessors with the candle types for errors,
+// dtypes, etc
+pub struct SafeTensors<'a>(st::SafeTensors<'a>);
+
+pub struct MmapedFile(memmap2::Mmap);
+
+impl MmapedFile {
+ pub fn new<P: AsRef<std::path::Path>>(p: P) -> Result<Self> {
+ let file = std::fs::File::open(p)?;
+ let mmap = unsafe { memmap2::MmapOptions::new().map(&file)? };
+ Ok(Self(mmap))
+ }
+
+ pub fn deserialize(&self) -> Result<SafeTensors<'_>> {
+ let st = safetensors::SafeTensors::deserialize(&self.0)?;
+ Ok(SafeTensors(st))
+ }
+}
+
+impl<'a> SafeTensors<'a> {
+ pub fn tensor(&self, name: &str, device: &Device) -> Result<Tensor> {
+ convert(self.0.tensor(name)?, device)
+ }
+
+ pub fn tensors(&self, device: &Device) -> Result<Vec<(String, Tensor)>> {
+ self.0
+ .tensors()
+ .into_iter()
+ .map(|(name, tensor_view)| {
+ let tensor = convert(tensor_view, device)?;
+ Ok((name, tensor))
+ })
+ .collect()
+ }
+
+ pub fn names(&self) -> Vec<&String> {
+ self.0.names()
}
}