summaryrefslogtreecommitdiff
path: root/candle-core/src
diff options
context:
space:
mode:
authorlaurent <laurent.mazare@gmail.com>2023-07-03 08:37:46 +0100
committerlaurent <laurent.mazare@gmail.com>2023-07-03 08:37:46 +0100
commitcf2789fb819049cb33a52d73b84a5810cc27cc97 (patch)
treed2944472bcc6557de5b712aeda280bb476adadf3 /candle-core/src
parent9e419641fb5594435ea8f0abd04547db0991c2b2 (diff)
downloadcandle-cf2789fb819049cb33a52d73b84a5810cc27cc97.tar.gz
candle-cf2789fb819049cb33a52d73b84a5810cc27cc97.tar.bz2
candle-cf2789fb819049cb33a52d73b84a5810cc27cc97.zip
Move some safetensors bits in the candle-core crate.
Diffstat (limited to 'candle-core/src')
-rw-r--r--candle-core/src/lib.rs1
-rw-r--r--candle-core/src/safetensors.rs27
2 files changed, 28 insertions, 0 deletions
diff --git a/candle-core/src/lib.rs b/candle-core/src/lib.rs
index 6a860116..0d4c2a8d 100644
--- a/candle-core/src/lib.rs
+++ b/candle-core/src/lib.rs
@@ -10,6 +10,7 @@ mod error;
mod layout;
mod npy;
mod op;
+pub mod safetensors;
mod shape;
mod storage;
mod strided_index;
diff --git a/candle-core/src/safetensors.rs b/candle-core/src/safetensors.rs
new file mode 100644
index 00000000..3ed36b64
--- /dev/null
+++ b/candle-core/src/safetensors.rs
@@ -0,0 +1,27 @@
+use crate::{Device, Result, Tensor};
+use half::f16;
+use safetensors::tensor as st;
+
+pub fn convert(view: st::TensorView<'_>, device: &Device) -> Result<Tensor> {
+ match view.dtype() {
+ st::Dtype::F16 => {
+ let v = view.data();
+ if (v.as_ptr() as usize) % 2 == 0 {
+ // SAFETY This is safe because we just checked that this
+ // was correctly aligned.
+ let data: &[f16] =
+ unsafe { std::slice::from_raw_parts(v.as_ptr() as *const f16, v.len() / 2) };
+ Tensor::from_slice(data, view.shape(), device)
+ } else {
+ let mut c = Vec::with_capacity(v.len() / 2);
+ let mut i = 0;
+ while i < v.len() {
+ c.push(f16::from_le_bytes([v[i], v[i + 1]]));
+ i += 2;
+ }
+ Tensor::from_slice(&c, view.shape(), device)
+ }
+ }
+ dt => todo!("Unhandled dtype {dt:?}"),
+ }
+}