diff options
-rw-r--r-- | candle-core/src/safetensors.rs | 60 | ||||
-rw-r--r-- | candle-examples/examples/llama2-c/main.rs | 3 |
2 files changed, 61 insertions, 2 deletions
diff --git a/candle-core/src/safetensors.rs b/candle-core/src/safetensors.rs index 0e1cc655..06b9b23b 100644 --- a/candle-core/src/safetensors.rs +++ b/candle-core/src/safetensors.rs @@ -109,6 +109,46 @@ fn convert_slice<T: WithDType>(data: &[u8], shape: &[usize], device: &Device) -> } } +fn convert_slice_with_cast<T: Sized + Copy, U: WithDType, F: Fn(T) -> Result<U>>( + data: &[u8], + shape: &[usize], + device: &Device, + conv: F, +) -> Result<Tensor> { + let size_in_bytes = std::mem::size_of::<T>(); + let elem_count = data.len() / size_in_bytes; + if (data.as_ptr() as usize) % size_in_bytes == 0 { + // SAFETY This is safe because we just checked that this + // was correctly aligned. + let data: &[T] = + unsafe { std::slice::from_raw_parts(data.as_ptr() as *const T, elem_count) }; + let data = data.iter().map(|t| conv(*t)).collect::<Result<Vec<_>>>()?; + Tensor::from_vec(data, shape, device) + } else { + // XXX: We need to specify `T` here, otherwise the compiler will infer u8 because of the following cast + // Making this vector too small to fit a full f16/f32/f64 weights, resulting in out-of-bounds access + let mut c: Vec<T> = Vec::with_capacity(elem_count); + // SAFETY: We just created c, so the allocated memory is necessarily + // contiguous and non overlapping with the view's data. + // We're downgrading the `c` pointer from T to u8, which removes alignment + // constraints. + unsafe { + std::ptr::copy_nonoverlapping(data.as_ptr(), c.as_mut_ptr() as *mut u8, data.len()); + c.set_len(elem_count) + } + let c = c.into_iter().map(conv).collect::<Result<Vec<_>>>()?; + Tensor::from_vec(c, shape, device) + } +} + +fn convert_with_cast_<T: Sized + Copy, U: WithDType, F: Fn(T) -> Result<U>>( + view: &st::TensorView<'_>, + device: &Device, + conv: F, +) -> Result<Tensor> { + convert_slice_with_cast::<T, U, F>(view.data(), view.shape(), device, conv) +} + fn convert_<T: WithDType>(view: &st::TensorView<'_>, device: &Device) -> Result<Tensor> { convert_slice::<T>(view.data(), view.shape(), device) } @@ -158,11 +198,29 @@ impl Tensor { fn convert(view: &st::TensorView<'_>, device: &Device) -> Result<Tensor> { match view.dtype() { st::Dtype::U8 => convert_::<u8>(view, device), - st::Dtype::U32 => convert_::<u8>(view, device), + st::Dtype::U16 => { + let conv = |x| Ok(u32::from(x)); + convert_with_cast_::<u16, u32, _>(view, device, conv) + } + st::Dtype::U32 => convert_::<u32>(view, device), st::Dtype::BF16 => convert_::<half::bf16>(view, device), st::Dtype::F16 => convert_::<half::f16>(view, device), st::Dtype::F32 => convert_::<f32>(view, device), st::Dtype::F64 => convert_::<f64>(view, device), + st::Dtype::I32 => { + let conv = |x| { + u32::try_from(x) + .map_err(|_| Error::Msg(format!("out of bounds value for u32: {x}"))) + }; + convert_with_cast_::<i32, u32, _>(view, device, conv) + } + st::Dtype::I64 => { + let conv = |x| { + u32::try_from(x) + .map_err(|_| Error::Msg(format!("out of bounds value for u32: {x}"))) + }; + convert_with_cast_::<i64, u32, _>(view, device, conv) + } dtype => Err(Error::UnsupportedSafeTensorDtype(dtype)), } } diff --git a/candle-examples/examples/llama2-c/main.rs b/candle-examples/examples/llama2-c/main.rs index 65641b3c..d710652f 100644 --- a/candle-examples/examples/llama2-c/main.rs +++ b/candle-examples/examples/llama2-c/main.rs @@ -266,7 +266,8 @@ fn run_eval(tokenizer: Tokenizer, config_path: &std::path::PathBuf, args: Args) let file = std::io::BufReader::new(file); let mut tokens = vec![]; for line in file.lines() { - let line = tokenizer.encode(line?, false).map_err(E::msg)?; + let line = line?.replace("<|endoftext|>", ""); + let line = tokenizer.encode(line, false).map_err(E::msg)?; tokens.push(line.get_ids().to_vec()) } let tokens = tokens.concat(); |