diff options
Diffstat (limited to 'candle-core/src/quantized/ggml_file.rs')
-rw-r--r-- | candle-core/src/quantized/ggml_file.rs | 131 |
1 files changed, 26 insertions, 105 deletions
diff --git a/candle-core/src/quantized/ggml_file.rs b/candle-core/src/quantized/ggml_file.rs index 2824f075..ee23cdde 100644 --- a/candle-core/src/quantized/ggml_file.rs +++ b/candle-core/src/quantized/ggml_file.rs @@ -1,7 +1,7 @@ //! Support for the GGML file format. use super::{k_quants, GgmlDType}; -use crate::{DType, Device, Result, Tensor}; +use crate::Result; use byteorder::{LittleEndian, ReadBytesExt}; // https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/llama.h#L37 @@ -116,121 +116,47 @@ impl Vocab { } } -fn dequantize_and_create_tensor<T: super::GgmlType>( +fn from_raw_data<T: super::GgmlType + Send + Sync + 'static>( raw_data: &[u8], - tensor_elems: usize, size_in_bytes: usize, dims: Vec<usize>, - device: &Device, -) -> Result<Tensor> { - let mut f32_data = vec![0f32; tensor_elems]; +) -> Result<super::QTensor> { let raw_data_ptr = raw_data.as_ptr(); let n_blocks = size_in_bytes / std::mem::size_of::<T>(); - let raw_data = unsafe { std::slice::from_raw_parts(raw_data_ptr as *const T, n_blocks) }; - T::to_float(raw_data, &mut f32_data)?; - Tensor::from_vec(f32_data, dims, device) + let data = unsafe { std::slice::from_raw_parts(raw_data_ptr as *const T, n_blocks) }; + Ok(super::QTensor::new(data.to_vec(), dims)) } /// Creates a [Tensor] from a raw GGML tensor. -pub fn tensor_from_ggml( +pub fn qtensor_from_ggml( ggml_dtype: GgmlDType, raw_data: &[u8], dims: Vec<usize>, - dtype: DType, - device: &Device, -) -> Result<Tensor> { +) -> Result<super::QTensor> { let tensor_elems = dims.iter().product::<usize>(); let size_in_bytes = tensor_elems * ggml_dtype.type_size() / ggml_dtype.blck_size(); - let tensor = match ggml_dtype { - GgmlDType::F32 => Tensor::from_raw_buffer(raw_data, DType::F32, &dims, device), - GgmlDType::F16 => Tensor::from_raw_buffer(raw_data, DType::F16, &dims, device), - GgmlDType::Q4_0 => dequantize_and_create_tensor::<k_quants::BlockQ4_0>( - raw_data, - tensor_elems, - size_in_bytes, - dims, - device, - ), - GgmlDType::Q4_1 => dequantize_and_create_tensor::<k_quants::BlockQ4_1>( - raw_data, - tensor_elems, - size_in_bytes, - dims, - device, - ), - GgmlDType::Q5_0 => dequantize_and_create_tensor::<k_quants::BlockQ5_0>( - raw_data, - tensor_elems, - size_in_bytes, - dims, - device, - ), - GgmlDType::Q5_1 => dequantize_and_create_tensor::<k_quants::BlockQ5_1>( - raw_data, - tensor_elems, - size_in_bytes, - dims, - device, - ), - GgmlDType::Q8_0 => dequantize_and_create_tensor::<k_quants::BlockQ8_0>( - raw_data, - tensor_elems, - size_in_bytes, - dims, - device, - ), - GgmlDType::Q2K => dequantize_and_create_tensor::<k_quants::BlockQ2K>( - raw_data, - tensor_elems, - size_in_bytes, - dims, - device, - ), - GgmlDType::Q3K => dequantize_and_create_tensor::<k_quants::BlockQ3K>( - raw_data, - tensor_elems, - size_in_bytes, - dims, - device, - ), - GgmlDType::Q4K => dequantize_and_create_tensor::<k_quants::BlockQ4K>( - raw_data, - tensor_elems, - size_in_bytes, - dims, - device, - ), - GgmlDType::Q5K => dequantize_and_create_tensor::<k_quants::BlockQ5K>( - raw_data, - tensor_elems, - size_in_bytes, - dims, - device, - ), - GgmlDType::Q6K => dequantize_and_create_tensor::<k_quants::BlockQ6K>( - raw_data, - tensor_elems, - size_in_bytes, - dims, - device, - ), - _ => crate::bail!("quantized type {dtype:?} is not supported yet"), - }?; - //We only have ggml-quant to f32 conversions, meaning we have to convert to the desired type - if tensor.dtype() != dtype { - tensor.to_dtype(dtype) - } else { - Ok(tensor) + match ggml_dtype { + GgmlDType::F32 => from_raw_data::<f32>(raw_data, size_in_bytes, dims), + GgmlDType::F16 => from_raw_data::<half::f16>(raw_data, size_in_bytes, dims), + GgmlDType::Q4_0 => from_raw_data::<k_quants::BlockQ4_0>(raw_data, size_in_bytes, dims), + GgmlDType::Q4_1 => from_raw_data::<k_quants::BlockQ4_1>(raw_data, size_in_bytes, dims), + GgmlDType::Q5_0 => from_raw_data::<k_quants::BlockQ5_0>(raw_data, size_in_bytes, dims), + GgmlDType::Q5_1 => from_raw_data::<k_quants::BlockQ5_1>(raw_data, size_in_bytes, dims), + GgmlDType::Q8_0 => from_raw_data::<k_quants::BlockQ8_0>(raw_data, size_in_bytes, dims), + GgmlDType::Q2K => from_raw_data::<k_quants::BlockQ2K>(raw_data, size_in_bytes, dims), + GgmlDType::Q3K => from_raw_data::<k_quants::BlockQ3K>(raw_data, size_in_bytes, dims), + GgmlDType::Q4K => from_raw_data::<k_quants::BlockQ4K>(raw_data, size_in_bytes, dims), + GgmlDType::Q5K => from_raw_data::<k_quants::BlockQ5K>(raw_data, size_in_bytes, dims), + GgmlDType::Q6K => from_raw_data::<k_quants::BlockQ6K>(raw_data, size_in_bytes, dims), + _ => crate::bail!("quantized type {ggml_dtype:?} is not supported yet"), } } fn read_one_tensor<R: std::io::Seek + std::io::Read>( reader: &mut R, magic: VersionedMagic, - dtype: DType, - device: &Device, -) -> Result<(String, Tensor)> { +) -> Result<(String, super::QTensor)> { let n_dims = reader.read_u32::<LittleEndian>()?; let name_len = reader.read_u32::<LittleEndian>()?; let ggml_dtype = reader.read_u32::<LittleEndian>()?; @@ -252,26 +178,21 @@ fn read_one_tensor<R: std::io::Seek + std::io::Read>( // TODO: Mmap version to avoid copying the data around? let mut raw_data = vec![0u8; size_in_bytes]; reader.read_exact(&mut raw_data)?; - match tensor_from_ggml(ggml_dtype, &raw_data, dims, dtype, device) { + match qtensor_from_ggml(ggml_dtype, &raw_data, dims) { Ok(tensor) => Ok((name, tensor)), Err(e) => crate::bail!("Error creating tensor {name}: {e}"), } } -#[derive(Debug)] pub struct Content { pub magic: VersionedMagic, pub hparams: HParams, pub vocab: Vocab, - pub tensors: Vec<(String, Tensor)>, + pub tensors: Vec<(String, super::QTensor)>, } impl Content { - pub fn read<R: std::io::Seek + std::io::Read>( - reader: &mut R, - dtype: DType, - device: &Device, - ) -> Result<Content> { + pub fn read<R: std::io::Seek + std::io::Read>(reader: &mut R) -> Result<Content> { // https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/llama.cpp#L505 let last_position = reader.seek(std::io::SeekFrom::End(0))?; reader.seek(std::io::SeekFrom::Start(0))?; @@ -281,7 +202,7 @@ impl Content { let mut tensors = vec![]; while reader.stream_position()? != last_position { - let (name, tensor) = read_one_tensor(reader, magic, dtype, device)?; + let (name, tensor) = read_one_tensor(reader, magic)?; tensors.push((name, tensor)) } Ok(Self { |