diff options
Diffstat (limited to 'candle-core/src')
-rw-r--r-- | candle-core/src/error.rs | 4 | ||||
-rw-r--r-- | candle-core/src/lib.rs | 1 | ||||
-rw-r--r-- | candle-core/src/safetensors.rs | 27 |
3 files changed, 32 insertions, 0 deletions
diff --git a/candle-core/src/error.rs b/candle-core/src/error.rs index 341fc151..71fd21de 100644 --- a/candle-core/src/error.rs +++ b/candle-core/src/error.rs @@ -110,6 +110,10 @@ pub enum Error { #[error(transparent)] Io(#[from] std::io::Error), + /// SafeTensor error. + #[error(transparent)] + SafeTensor(#[from] safetensors::SafeTensorError), + #[error("cannot broadcast {src_shape:?} to {dst_shape:?}")] BroadcastIncompatibleShapes { src_shape: Shape, dst_shape: Shape }, } diff --git a/candle-core/src/lib.rs b/candle-core/src/lib.rs index 6a860116..0d4c2a8d 100644 --- a/candle-core/src/lib.rs +++ b/candle-core/src/lib.rs @@ -10,6 +10,7 @@ mod error; mod layout; mod npy; mod op; +pub mod safetensors; mod shape; mod storage; mod strided_index; diff --git a/candle-core/src/safetensors.rs b/candle-core/src/safetensors.rs new file mode 100644 index 00000000..3ed36b64 --- /dev/null +++ b/candle-core/src/safetensors.rs @@ -0,0 +1,27 @@ +use crate::{Device, Result, Tensor}; +use half::f16; +use safetensors::tensor as st; + +pub fn convert(view: st::TensorView<'_>, device: &Device) -> Result<Tensor> { + match view.dtype() { + st::Dtype::F16 => { + let v = view.data(); + if (v.as_ptr() as usize) % 2 == 0 { + // SAFETY This is safe because we just checked that this + // was correctly aligned. + let data: &[f16] = + unsafe { std::slice::from_raw_parts(v.as_ptr() as *const f16, v.len() / 2) }; + Tensor::from_slice(data, view.shape(), device) + } else { + let mut c = Vec::with_capacity(v.len() / 2); + let mut i = 0; + while i < v.len() { + c.push(f16::from_le_bytes([v[i], v[i + 1]])); + i += 2; + } + Tensor::from_slice(&c, view.shape(), device) + } + } + dt => todo!("Unhandled dtype {dt:?}"), + } +} |