summaryrefslogtreecommitdiff
path: root/candle-core/examples/llama/weights.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-core/examples/llama/weights.rs')
-rw-r--r--candle-core/examples/llama/weights.rs33
1 files changed, 2 insertions, 31 deletions
diff --git a/candle-core/examples/llama/weights.rs b/candle-core/examples/llama/weights.rs
index 5eff8e21..4ad9b391 100644
--- a/candle-core/examples/llama/weights.rs
+++ b/candle-core/examples/llama/weights.rs
@@ -1,38 +1,10 @@
use super::*;
use candle::{Device, Result, Tensor};
-use half::f16;
use memmap2::MmapOptions;
-use safetensors::{
- tensor::{Dtype, TensorView},
- SafeTensors,
-};
+use safetensors::SafeTensors;
use std::fs::File;
use std::path::PathBuf;
-fn convert(view: TensorView<'_>, device: &Device) -> Result<Tensor> {
- match view.dtype() {
- Dtype::F16 => {
- let v = view.data();
- if (v.as_ptr() as usize) % 2 == 0 {
- // SAFETY This is safe because we just checked that this
- // was correctly aligned.
- let data: &[f16] =
- unsafe { std::slice::from_raw_parts(v.as_ptr() as *const f16, v.len() / 2) };
- Tensor::from_slice(data, view.shape(), device)?.to_dtype(DTYPE)
- } else {
- let mut c = Vec::with_capacity(v.len() / 2);
- let mut i = 0;
- while i < v.len() {
- c.push(f16::from_le_bytes([v[i], v[i + 1]]));
- i += 2;
- }
- Tensor::from_slice(&c, view.shape(), device)?.to_dtype(DTYPE)
- }
- }
- dt => todo!("Unhandled dtype {dt:?}"),
- }
-}
-
pub struct VarBuilder<'a> {
routing: HashMap<String, usize>,
safetensors: Vec<SafeTensors<'a>>,
@@ -59,8 +31,7 @@ impl<'a> VarBuilder<'a> {
// Unwrap or 0 just to let the proper error flow.
let index = self.routing.get(tensor_name).unwrap_or(&0);
let view = self.safetensors[*index].tensor(tensor_name).unwrap();
- let tensor = convert(view, &self.device)?;
- Ok(tensor)
+ candle::safetensors::convert(view, &self.device)?.to_dtype(DTYPE)
}
}