summaryrefslogtreecommitdiff
path: root/candle-core/src
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-08-15 22:45:53 +0100
committerGitHub <noreply@github.com>2023-08-15 22:45:53 +0100
commitca449f9ee11b892e026972d114c77a0938e1dc0b (patch)
treefd179ab9ffd5d1a3da740506a091147df9ba39e5 /candle-core/src
parentb8263aa15cf2d8d0f425e25bae296ea4e96aeb88 (diff)
downloadcandle-ca449f9ee11b892e026972d114c77a0938e1dc0b.tar.gz
candle-ca449f9ee11b892e026972d114c77a0938e1dc0b.tar.bz2
candle-ca449f9ee11b892e026972d114c77a0938e1dc0b.zip
Add quantized tensors. (#458)
* Add quantized tensors. * Implement the debug trait for QTensor. * Add the QMatMul custom op.
Diffstat (limited to 'candle-core/src')
-rw-r--r--candle-core/src/quantized/ggml_file.rs131
-rw-r--r--candle-core/src/quantized/mod.rs114
2 files changed, 139 insertions, 106 deletions
diff --git a/candle-core/src/quantized/ggml_file.rs b/candle-core/src/quantized/ggml_file.rs
index 2824f075..ee23cdde 100644
--- a/candle-core/src/quantized/ggml_file.rs
+++ b/candle-core/src/quantized/ggml_file.rs
@@ -1,7 +1,7 @@
//! Support for the GGML file format.
use super::{k_quants, GgmlDType};
-use crate::{DType, Device, Result, Tensor};
+use crate::Result;
use byteorder::{LittleEndian, ReadBytesExt};
// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/llama.h#L37
@@ -116,121 +116,47 @@ impl Vocab {
}
}
-fn dequantize_and_create_tensor<T: super::GgmlType>(
+fn from_raw_data<T: super::GgmlType + Send + Sync + 'static>(
raw_data: &[u8],
- tensor_elems: usize,
size_in_bytes: usize,
dims: Vec<usize>,
- device: &Device,
-) -> Result<Tensor> {
- let mut f32_data = vec![0f32; tensor_elems];
+) -> Result<super::QTensor> {
let raw_data_ptr = raw_data.as_ptr();
let n_blocks = size_in_bytes / std::mem::size_of::<T>();
- let raw_data = unsafe { std::slice::from_raw_parts(raw_data_ptr as *const T, n_blocks) };
- T::to_float(raw_data, &mut f32_data)?;
- Tensor::from_vec(f32_data, dims, device)
+ let data = unsafe { std::slice::from_raw_parts(raw_data_ptr as *const T, n_blocks) };
+ Ok(super::QTensor::new(data.to_vec(), dims))
}
/// Creates a [Tensor] from a raw GGML tensor.
-pub fn tensor_from_ggml(
+pub fn qtensor_from_ggml(
ggml_dtype: GgmlDType,
raw_data: &[u8],
dims: Vec<usize>,
- dtype: DType,
- device: &Device,
-) -> Result<Tensor> {
+) -> Result<super::QTensor> {
let tensor_elems = dims.iter().product::<usize>();
let size_in_bytes = tensor_elems * ggml_dtype.type_size() / ggml_dtype.blck_size();
- let tensor = match ggml_dtype {
- GgmlDType::F32 => Tensor::from_raw_buffer(raw_data, DType::F32, &dims, device),
- GgmlDType::F16 => Tensor::from_raw_buffer(raw_data, DType::F16, &dims, device),
- GgmlDType::Q4_0 => dequantize_and_create_tensor::<k_quants::BlockQ4_0>(
- raw_data,
- tensor_elems,
- size_in_bytes,
- dims,
- device,
- ),
- GgmlDType::Q4_1 => dequantize_and_create_tensor::<k_quants::BlockQ4_1>(
- raw_data,
- tensor_elems,
- size_in_bytes,
- dims,
- device,
- ),
- GgmlDType::Q5_0 => dequantize_and_create_tensor::<k_quants::BlockQ5_0>(
- raw_data,
- tensor_elems,
- size_in_bytes,
- dims,
- device,
- ),
- GgmlDType::Q5_1 => dequantize_and_create_tensor::<k_quants::BlockQ5_1>(
- raw_data,
- tensor_elems,
- size_in_bytes,
- dims,
- device,
- ),
- GgmlDType::Q8_0 => dequantize_and_create_tensor::<k_quants::BlockQ8_0>(
- raw_data,
- tensor_elems,
- size_in_bytes,
- dims,
- device,
- ),
- GgmlDType::Q2K => dequantize_and_create_tensor::<k_quants::BlockQ2K>(
- raw_data,
- tensor_elems,
- size_in_bytes,
- dims,
- device,
- ),
- GgmlDType::Q3K => dequantize_and_create_tensor::<k_quants::BlockQ3K>(
- raw_data,
- tensor_elems,
- size_in_bytes,
- dims,
- device,
- ),
- GgmlDType::Q4K => dequantize_and_create_tensor::<k_quants::BlockQ4K>(
- raw_data,
- tensor_elems,
- size_in_bytes,
- dims,
- device,
- ),
- GgmlDType::Q5K => dequantize_and_create_tensor::<k_quants::BlockQ5K>(
- raw_data,
- tensor_elems,
- size_in_bytes,
- dims,
- device,
- ),
- GgmlDType::Q6K => dequantize_and_create_tensor::<k_quants::BlockQ6K>(
- raw_data,
- tensor_elems,
- size_in_bytes,
- dims,
- device,
- ),
- _ => crate::bail!("quantized type {dtype:?} is not supported yet"),
- }?;
- //We only have ggml-quant to f32 conversions, meaning we have to convert to the desired type
- if tensor.dtype() != dtype {
- tensor.to_dtype(dtype)
- } else {
- Ok(tensor)
+ match ggml_dtype {
+ GgmlDType::F32 => from_raw_data::<f32>(raw_data, size_in_bytes, dims),
+ GgmlDType::F16 => from_raw_data::<half::f16>(raw_data, size_in_bytes, dims),
+ GgmlDType::Q4_0 => from_raw_data::<k_quants::BlockQ4_0>(raw_data, size_in_bytes, dims),
+ GgmlDType::Q4_1 => from_raw_data::<k_quants::BlockQ4_1>(raw_data, size_in_bytes, dims),
+ GgmlDType::Q5_0 => from_raw_data::<k_quants::BlockQ5_0>(raw_data, size_in_bytes, dims),
+ GgmlDType::Q5_1 => from_raw_data::<k_quants::BlockQ5_1>(raw_data, size_in_bytes, dims),
+ GgmlDType::Q8_0 => from_raw_data::<k_quants::BlockQ8_0>(raw_data, size_in_bytes, dims),
+ GgmlDType::Q2K => from_raw_data::<k_quants::BlockQ2K>(raw_data, size_in_bytes, dims),
+ GgmlDType::Q3K => from_raw_data::<k_quants::BlockQ3K>(raw_data, size_in_bytes, dims),
+ GgmlDType::Q4K => from_raw_data::<k_quants::BlockQ4K>(raw_data, size_in_bytes, dims),
+ GgmlDType::Q5K => from_raw_data::<k_quants::BlockQ5K>(raw_data, size_in_bytes, dims),
+ GgmlDType::Q6K => from_raw_data::<k_quants::BlockQ6K>(raw_data, size_in_bytes, dims),
+ _ => crate::bail!("quantized type {ggml_dtype:?} is not supported yet"),
}
}
fn read_one_tensor<R: std::io::Seek + std::io::Read>(
reader: &mut R,
magic: VersionedMagic,
- dtype: DType,
- device: &Device,
-) -> Result<(String, Tensor)> {
+) -> Result<(String, super::QTensor)> {
let n_dims = reader.read_u32::<LittleEndian>()?;
let name_len = reader.read_u32::<LittleEndian>()?;
let ggml_dtype = reader.read_u32::<LittleEndian>()?;
@@ -252,26 +178,21 @@ fn read_one_tensor<R: std::io::Seek + std::io::Read>(
// TODO: Mmap version to avoid copying the data around?
let mut raw_data = vec![0u8; size_in_bytes];
reader.read_exact(&mut raw_data)?;
- match tensor_from_ggml(ggml_dtype, &raw_data, dims, dtype, device) {
+ match qtensor_from_ggml(ggml_dtype, &raw_data, dims) {
Ok(tensor) => Ok((name, tensor)),
Err(e) => crate::bail!("Error creating tensor {name}: {e}"),
}
}
-#[derive(Debug)]
pub struct Content {
pub magic: VersionedMagic,
pub hparams: HParams,
pub vocab: Vocab,
- pub tensors: Vec<(String, Tensor)>,
+ pub tensors: Vec<(String, super::QTensor)>,
}
impl Content {
- pub fn read<R: std::io::Seek + std::io::Read>(
- reader: &mut R,
- dtype: DType,
- device: &Device,
- ) -> Result<Content> {
+ pub fn read<R: std::io::Seek + std::io::Read>(reader: &mut R) -> Result<Content> {
// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/llama.cpp#L505
let last_position = reader.seek(std::io::SeekFrom::End(0))?;
reader.seek(std::io::SeekFrom::Start(0))?;
@@ -281,7 +202,7 @@ impl Content {
let mut tensors = vec![];
while reader.stream_position()? != last_position {
- let (name, tensor) = read_one_tensor(reader, magic, dtype, device)?;
+ let (name, tensor) = read_one_tensor(reader, magic)?;
tensors.push((name, tensor))
}
Ok(Self {
diff --git a/candle-core/src/quantized/mod.rs b/candle-core/src/quantized/mod.rs
index c7e24592..842b519b 100644
--- a/candle-core/src/quantized/mod.rs
+++ b/candle-core/src/quantized/mod.rs
@@ -1,10 +1,15 @@
-use crate::Result;
+use crate::{Device, Result, Shape, Tensor};
pub mod ggml_file;
pub mod k_quants;
pub use k_quants::GgmlType;
+pub struct QTensor {
+ data: Box<dyn QuantizedType>,
+ shape: Shape,
+}
+
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GgmlDType {
F32,
@@ -80,3 +85,110 @@ impl GgmlDType {
}
}
}
+
+// A version of GgmlType without `vec_dot` so that it can be dyn boxed.
+pub trait QuantizedType: Send + Sync {
+ fn dtype(&self) -> GgmlDType;
+ fn matmul_t(&self, mkn: (usize, usize, usize), lhs: &[f32], dst: &mut [f32]) -> Result<()>;
+ fn to_float(&self, ys: &mut [f32]) -> Result<()>;
+}
+
+impl<T: k_quants::GgmlType + Send + Sync> QuantizedType for Vec<T> {
+ fn matmul_t(&self, mkn: (usize, usize, usize), lhs: &[f32], dst: &mut [f32]) -> Result<()> {
+ k_quants::matmul(mkn, lhs, self.as_slice(), dst)
+ }
+
+ fn dtype(&self) -> GgmlDType {
+ T::DTYPE
+ }
+
+ fn to_float(&self, ys: &mut [f32]) -> Result<()> {
+ T::to_float(self.as_slice(), ys)
+ }
+}
+
+impl std::fmt::Debug for QTensor {
+ fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+ write!(f, "QTensor[{:?}; {:?}]", self.shape, self.dtype())
+ }
+}
+
+impl QTensor {
+ pub fn new<S: Into<Shape>, T: k_quants::GgmlType + Send + Sync + 'static>(
+ data: Vec<T>,
+ shape: S,
+ ) -> Self {
+ Self {
+ data: Box::new(data),
+ shape: shape.into(),
+ }
+ }
+
+ pub fn dtype(&self) -> GgmlDType {
+ self.data.dtype()
+ }
+
+ pub fn shape(&self) -> &Shape {
+ &self.shape
+ }
+
+ pub fn dequantize(&self, device: &Device) -> Result<Tensor> {
+ let mut f32_data = vec![0f32; self.shape.elem_count()];
+ self.data.to_float(&mut f32_data)?;
+ Tensor::from_vec(f32_data, &self.shape, device)
+ }
+
+ pub fn matmul_t(&self, mkn: (usize, usize, usize), lhs: &[f32], dst: &mut [f32]) -> Result<()> {
+ self.data.matmul_t(mkn, lhs, dst)
+ }
+}
+
+#[derive(Debug, Clone)]
+pub struct QMatMul(std::sync::Arc<QTensor>);
+
+impl QMatMul {
+ pub fn new(qtensor: std::sync::Arc<QTensor>) -> Self {
+ Self(qtensor)
+ }
+}
+
+impl crate::CustomOp1 for QMatMul {
+ fn name(&self) -> &'static str {
+ "qmatmul"
+ }
+
+ fn cpu_fwd(
+ &self,
+ storage: &crate::CpuStorage,
+ layout: &crate::Layout,
+ ) -> Result<(crate::CpuStorage, Shape)> {
+ if !layout.is_contiguous() {
+ crate::bail!("input tensor is not contiguous {layout:?}")
+ }
+ let src_shape = layout.shape();
+ let (k, n) = self.0.shape.dims2()?;
+ if src_shape.rank() < 2 {
+ crate::bail!("input tensor has only one dimension {layout:?}")
+ }
+ let mut dst_shape = src_shape.dims().to_vec();
+ let last_k = dst_shape.pop().unwrap();
+ if last_k != k {
+ crate::bail!(
+ "input tensor {layout:?} incompatible with {:?}",
+ self.0.shape
+ )
+ }
+ dst_shape.push(n);
+ let dst_shape = Shape::from(dst_shape);
+ let storage = storage.as_slice::<f32>()?;
+ let storage =
+ &storage[layout.start_offset()..layout.start_offset() + src_shape.elem_count()];
+ let mut dst_storage = vec![0f32; dst_shape.elem_count()];
+ self.0.matmul_t(
+ (dst_shape.elem_count() / n, k, n),
+ storage,
+ &mut dst_storage,
+ )?;
+ Ok((crate::CpuStorage::F32(dst_storage), dst_shape))
+ }
+}