summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-core/examples/llama/weights.rs33
-rw-r--r--candle-core/src/error.rs4
-rw-r--r--candle-core/src/lib.rs1
-rw-r--r--candle-core/src/safetensors.rs27
4 files changed, 34 insertions, 31 deletions
diff --git a/candle-core/examples/llama/weights.rs b/candle-core/examples/llama/weights.rs
index 5eff8e21..4ad9b391 100644
--- a/candle-core/examples/llama/weights.rs
+++ b/candle-core/examples/llama/weights.rs
@@ -1,38 +1,10 @@
use super::*;
use candle::{Device, Result, Tensor};
-use half::f16;
use memmap2::MmapOptions;
-use safetensors::{
- tensor::{Dtype, TensorView},
- SafeTensors,
-};
+use safetensors::SafeTensors;
use std::fs::File;
use std::path::PathBuf;
-fn convert(view: TensorView<'_>, device: &Device) -> Result<Tensor> {
- match view.dtype() {
- Dtype::F16 => {
- let v = view.data();
- if (v.as_ptr() as usize) % 2 == 0 {
- // SAFETY This is safe because we just checked that this
- // was correctly aligned.
- let data: &[f16] =
- unsafe { std::slice::from_raw_parts(v.as_ptr() as *const f16, v.len() / 2) };
- Tensor::from_slice(data, view.shape(), device)?.to_dtype(DTYPE)
- } else {
- let mut c = Vec::with_capacity(v.len() / 2);
- let mut i = 0;
- while i < v.len() {
- c.push(f16::from_le_bytes([v[i], v[i + 1]]));
- i += 2;
- }
- Tensor::from_slice(&c, view.shape(), device)?.to_dtype(DTYPE)
- }
- }
- dt => todo!("Unhandled dtype {dt:?}"),
- }
-}
-
pub struct VarBuilder<'a> {
routing: HashMap<String, usize>,
safetensors: Vec<SafeTensors<'a>>,
@@ -59,8 +31,7 @@ impl<'a> VarBuilder<'a> {
// Unwrap or 0 just to let the proper error flow.
let index = self.routing.get(tensor_name).unwrap_or(&0);
let view = self.safetensors[*index].tensor(tensor_name).unwrap();
- let tensor = convert(view, &self.device)?;
- Ok(tensor)
+ candle::safetensors::convert(view, &self.device)?.to_dtype(DTYPE)
}
}
diff --git a/candle-core/src/error.rs b/candle-core/src/error.rs
index 341fc151..71fd21de 100644
--- a/candle-core/src/error.rs
+++ b/candle-core/src/error.rs
@@ -110,6 +110,10 @@ pub enum Error {
#[error(transparent)]
Io(#[from] std::io::Error),
+ /// SafeTensor error.
+ #[error(transparent)]
+ SafeTensor(#[from] safetensors::SafeTensorError),
+
#[error("cannot broadcast {src_shape:?} to {dst_shape:?}")]
BroadcastIncompatibleShapes { src_shape: Shape, dst_shape: Shape },
}
diff --git a/candle-core/src/lib.rs b/candle-core/src/lib.rs
index 6a860116..0d4c2a8d 100644
--- a/candle-core/src/lib.rs
+++ b/candle-core/src/lib.rs
@@ -10,6 +10,7 @@ mod error;
mod layout;
mod npy;
mod op;
+pub mod safetensors;
mod shape;
mod storage;
mod strided_index;
diff --git a/candle-core/src/safetensors.rs b/candle-core/src/safetensors.rs
new file mode 100644
index 00000000..3ed36b64
--- /dev/null
+++ b/candle-core/src/safetensors.rs
@@ -0,0 +1,27 @@
+use crate::{Device, Result, Tensor};
+use half::f16;
+use safetensors::tensor as st;
+
+pub fn convert(view: st::TensorView<'_>, device: &Device) -> Result<Tensor> {
+ match view.dtype() {
+ st::Dtype::F16 => {
+ let v = view.data();
+ if (v.as_ptr() as usize) % 2 == 0 {
+ // SAFETY This is safe because we just checked that this
+ // was correctly aligned.
+ let data: &[f16] =
+ unsafe { std::slice::from_raw_parts(v.as_ptr() as *const f16, v.len() / 2) };
+ Tensor::from_slice(data, view.shape(), device)
+ } else {
+ let mut c = Vec::with_capacity(v.len() / 2);
+ let mut i = 0;
+ while i < v.len() {
+ c.push(f16::from_le_bytes([v[i], v[i + 1]]));
+ i += 2;
+ }
+ Tensor::from_slice(&c, view.shape(), device)
+ }
+ }
+ dt => todo!("Unhandled dtype {dt:?}"),
+ }
+}