diff options
author | laurent <laurent.mazare@gmail.com> | 2023-07-03 08:37:46 +0100 |
---|---|---|
committer | laurent <laurent.mazare@gmail.com> | 2023-07-03 08:37:46 +0100 |
commit | cf2789fb819049cb33a52d73b84a5810cc27cc97 (patch) | |
tree | d2944472bcc6557de5b712aeda280bb476adadf3 | |
parent | 9e419641fb5594435ea8f0abd04547db0991c2b2 (diff) | |
download | candle-cf2789fb819049cb33a52d73b84a5810cc27cc97.tar.gz candle-cf2789fb819049cb33a52d73b84a5810cc27cc97.tar.bz2 candle-cf2789fb819049cb33a52d73b84a5810cc27cc97.zip |
Move some safetensors bits in the candle-core crate.
-rw-r--r-- | candle-core/examples/llama/weights.rs | 33 | ||||
-rw-r--r-- | candle-core/src/lib.rs | 1 | ||||
-rw-r--r-- | candle-core/src/safetensors.rs | 27 |
3 files changed, 30 insertions, 31 deletions
diff --git a/candle-core/examples/llama/weights.rs b/candle-core/examples/llama/weights.rs index 5eff8e21..4ad9b391 100644 --- a/candle-core/examples/llama/weights.rs +++ b/candle-core/examples/llama/weights.rs @@ -1,38 +1,10 @@ use super::*; use candle::{Device, Result, Tensor}; -use half::f16; use memmap2::MmapOptions; -use safetensors::{ - tensor::{Dtype, TensorView}, - SafeTensors, -}; +use safetensors::SafeTensors; use std::fs::File; use std::path::PathBuf; -fn convert(view: TensorView<'_>, device: &Device) -> Result<Tensor> { - match view.dtype() { - Dtype::F16 => { - let v = view.data(); - if (v.as_ptr() as usize) % 2 == 0 { - // SAFETY This is safe because we just checked that this - // was correctly aligned. - let data: &[f16] = - unsafe { std::slice::from_raw_parts(v.as_ptr() as *const f16, v.len() / 2) }; - Tensor::from_slice(data, view.shape(), device)?.to_dtype(DTYPE) - } else { - let mut c = Vec::with_capacity(v.len() / 2); - let mut i = 0; - while i < v.len() { - c.push(f16::from_le_bytes([v[i], v[i + 1]])); - i += 2; - } - Tensor::from_slice(&c, view.shape(), device)?.to_dtype(DTYPE) - } - } - dt => todo!("Unhandled dtype {dt:?}"), - } -} - pub struct VarBuilder<'a> { routing: HashMap<String, usize>, safetensors: Vec<SafeTensors<'a>>, @@ -59,8 +31,7 @@ impl<'a> VarBuilder<'a> { // Unwrap or 0 just to let the proper error flow. let index = self.routing.get(tensor_name).unwrap_or(&0); let view = self.safetensors[*index].tensor(tensor_name).unwrap(); - let tensor = convert(view, &self.device)?; - Ok(tensor) + candle::safetensors::convert(view, &self.device)?.to_dtype(DTYPE) } } diff --git a/candle-core/src/lib.rs b/candle-core/src/lib.rs index 6a860116..0d4c2a8d 100644 --- a/candle-core/src/lib.rs +++ b/candle-core/src/lib.rs @@ -10,6 +10,7 @@ mod error; mod layout; mod npy; mod op; +pub mod safetensors; mod shape; mod storage; mod strided_index; diff --git a/candle-core/src/safetensors.rs b/candle-core/src/safetensors.rs new file mode 100644 index 00000000..3ed36b64 --- /dev/null +++ b/candle-core/src/safetensors.rs @@ -0,0 +1,27 @@ +use crate::{Device, Result, Tensor}; +use half::f16; +use safetensors::tensor as st; + +pub fn convert(view: st::TensorView<'_>, device: &Device) -> Result<Tensor> { + match view.dtype() { + st::Dtype::F16 => { + let v = view.data(); + if (v.as_ptr() as usize) % 2 == 0 { + // SAFETY This is safe because we just checked that this + // was correctly aligned. + let data: &[f16] = + unsafe { std::slice::from_raw_parts(v.as_ptr() as *const f16, v.len() / 2) }; + Tensor::from_slice(data, view.shape(), device) + } else { + let mut c = Vec::with_capacity(v.len() / 2); + let mut i = 0; + while i < v.len() { + c.push(f16::from_le_bytes([v[i], v[i + 1]])); + i += 2; + } + Tensor::from_slice(&c, view.shape(), device) + } + } + dt => todo!("Unhandled dtype {dt:?}"), + } +} |