summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-core/src/safetensors.rs60
-rw-r--r--candle-examples/examples/llama2-c/main.rs3
2 files changed, 61 insertions, 2 deletions
diff --git a/candle-core/src/safetensors.rs b/candle-core/src/safetensors.rs
index 0e1cc655..06b9b23b 100644
--- a/candle-core/src/safetensors.rs
+++ b/candle-core/src/safetensors.rs
@@ -109,6 +109,46 @@ fn convert_slice<T: WithDType>(data: &[u8], shape: &[usize], device: &Device) ->
}
}
+fn convert_slice_with_cast<T: Sized + Copy, U: WithDType, F: Fn(T) -> Result<U>>(
+ data: &[u8],
+ shape: &[usize],
+ device: &Device,
+ conv: F,
+) -> Result<Tensor> {
+ let size_in_bytes = std::mem::size_of::<T>();
+ let elem_count = data.len() / size_in_bytes;
+ if (data.as_ptr() as usize) % size_in_bytes == 0 {
+ // SAFETY This is safe because we just checked that this
+ // was correctly aligned.
+ let data: &[T] =
+ unsafe { std::slice::from_raw_parts(data.as_ptr() as *const T, elem_count) };
+ let data = data.iter().map(|t| conv(*t)).collect::<Result<Vec<_>>>()?;
+ Tensor::from_vec(data, shape, device)
+ } else {
+ // XXX: We need to specify `T` here, otherwise the compiler will infer u8 because of the following cast
+ // Making this vector too small to fit a full f16/f32/f64 weights, resulting in out-of-bounds access
+ let mut c: Vec<T> = Vec::with_capacity(elem_count);
+ // SAFETY: We just created c, so the allocated memory is necessarily
+ // contiguous and non overlapping with the view's data.
+ // We're downgrading the `c` pointer from T to u8, which removes alignment
+ // constraints.
+ unsafe {
+ std::ptr::copy_nonoverlapping(data.as_ptr(), c.as_mut_ptr() as *mut u8, data.len());
+ c.set_len(elem_count)
+ }
+ let c = c.into_iter().map(conv).collect::<Result<Vec<_>>>()?;
+ Tensor::from_vec(c, shape, device)
+ }
+}
+
+fn convert_with_cast_<T: Sized + Copy, U: WithDType, F: Fn(T) -> Result<U>>(
+ view: &st::TensorView<'_>,
+ device: &Device,
+ conv: F,
+) -> Result<Tensor> {
+ convert_slice_with_cast::<T, U, F>(view.data(), view.shape(), device, conv)
+}
+
fn convert_<T: WithDType>(view: &st::TensorView<'_>, device: &Device) -> Result<Tensor> {
convert_slice::<T>(view.data(), view.shape(), device)
}
@@ -158,11 +198,29 @@ impl Tensor {
fn convert(view: &st::TensorView<'_>, device: &Device) -> Result<Tensor> {
match view.dtype() {
st::Dtype::U8 => convert_::<u8>(view, device),
- st::Dtype::U32 => convert_::<u8>(view, device),
+ st::Dtype::U16 => {
+ let conv = |x| Ok(u32::from(x));
+ convert_with_cast_::<u16, u32, _>(view, device, conv)
+ }
+ st::Dtype::U32 => convert_::<u32>(view, device),
st::Dtype::BF16 => convert_::<half::bf16>(view, device),
st::Dtype::F16 => convert_::<half::f16>(view, device),
st::Dtype::F32 => convert_::<f32>(view, device),
st::Dtype::F64 => convert_::<f64>(view, device),
+ st::Dtype::I32 => {
+ let conv = |x| {
+ u32::try_from(x)
+ .map_err(|_| Error::Msg(format!("out of bounds value for u32: {x}")))
+ };
+ convert_with_cast_::<i32, u32, _>(view, device, conv)
+ }
+ st::Dtype::I64 => {
+ let conv = |x| {
+ u32::try_from(x)
+ .map_err(|_| Error::Msg(format!("out of bounds value for u32: {x}")))
+ };
+ convert_with_cast_::<i64, u32, _>(view, device, conv)
+ }
dtype => Err(Error::UnsupportedSafeTensorDtype(dtype)),
}
}
diff --git a/candle-examples/examples/llama2-c/main.rs b/candle-examples/examples/llama2-c/main.rs
index 65641b3c..d710652f 100644
--- a/candle-examples/examples/llama2-c/main.rs
+++ b/candle-examples/examples/llama2-c/main.rs
@@ -266,7 +266,8 @@ fn run_eval(tokenizer: Tokenizer, config_path: &std::path::PathBuf, args: Args)
let file = std::io::BufReader::new(file);
let mut tokens = vec![];
for line in file.lines() {
- let line = tokenizer.encode(line?, false).map_err(E::msg)?;
+ let line = line?.replace("<|endoftext|>", "");
+ let line = tokenizer.encode(line, false).map_err(E::msg)?;
tokens.push(line.get_ids().to_vec())
}
let tokens = tokens.concat();