summaryrefslogtreecommitdiff
path: root/candle-core/src
diff options
context:
space:
mode:
Diffstat (limited to 'candle-core/src')
-rw-r--r--candle-core/src/error.rs4
-rw-r--r--candle-core/src/lib.rs1
-rw-r--r--candle-core/src/safetensors.rs27
3 files changed, 32 insertions, 0 deletions
diff --git a/candle-core/src/error.rs b/candle-core/src/error.rs
index 341fc151..71fd21de 100644
--- a/candle-core/src/error.rs
+++ b/candle-core/src/error.rs
@@ -110,6 +110,10 @@ pub enum Error {
#[error(transparent)]
Io(#[from] std::io::Error),
+ /// SafeTensor error.
+ #[error(transparent)]
+ SafeTensor(#[from] safetensors::SafeTensorError),
+
#[error("cannot broadcast {src_shape:?} to {dst_shape:?}")]
BroadcastIncompatibleShapes { src_shape: Shape, dst_shape: Shape },
}
diff --git a/candle-core/src/lib.rs b/candle-core/src/lib.rs
index 6a860116..0d4c2a8d 100644
--- a/candle-core/src/lib.rs
+++ b/candle-core/src/lib.rs
@@ -10,6 +10,7 @@ mod error;
mod layout;
mod npy;
mod op;
+pub mod safetensors;
mod shape;
mod storage;
mod strided_index;
diff --git a/candle-core/src/safetensors.rs b/candle-core/src/safetensors.rs
new file mode 100644
index 00000000..3ed36b64
--- /dev/null
+++ b/candle-core/src/safetensors.rs
@@ -0,0 +1,27 @@
+use crate::{Device, Result, Tensor};
+use half::f16;
+use safetensors::tensor as st;
+
+/// Builds a `Tensor` on `device` from a safetensors view, using the view's shape.
+/// Only `F16` is handled so far; any other dtype panics through `todo!`.
+pub fn convert(view: st::TensorView<'_>, device: &Device) -> Result<Tensor> {
+    match view.dtype() {
+        st::Dtype::F16 => {
+            let v = view.data();
+            // The bytes are little-endian, so reinterpreting is only valid on LE targets.
+            if cfg!(target_endian = "little") && (v.as_ptr() as usize) % 2 == 0 {
+                // SAFETY: the pointer is 2-byte aligned, `v.len() / 2` f16 values fit in
+                // `v`, and every bit pattern is a valid f16.
+                let data: &[f16] =
+                    unsafe { std::slice::from_raw_parts(v.as_ptr() as *const f16, v.len() / 2) };
+                Tensor::from_slice(data, view.shape(), device)
+            } else {
+                // Portable path: decode each 2-byte pair (a trailing odd byte is ignored).
+                let pairs = v.chunks_exact(2);
+                let c: Vec<f16> = pairs.map(|b| f16::from_le_bytes([b[0], b[1]])).collect();
+                Tensor::from_slice(&c, view.shape(), device)
+            }
+        }
+        dt => todo!("Unhandled dtype {dt:?}"),
+    }
+}