diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-08-15 22:45:53 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-08-15 22:45:53 +0100 |
commit | ca449f9ee11b892e026972d114c77a0938e1dc0b (patch) | |
tree | fd179ab9ffd5d1a3da740506a091147df9ba39e5 /candle-core/src/quantized/mod.rs | |
parent | b8263aa15cf2d8d0f425e25bae296ea4e96aeb88 (diff) | |
download | candle-ca449f9ee11b892e026972d114c77a0938e1dc0b.tar.gz candle-ca449f9ee11b892e026972d114c77a0938e1dc0b.tar.bz2 candle-ca449f9ee11b892e026972d114c77a0938e1dc0b.zip |
Add quantized tensors. (#458)
* Add quantized tensors.
* Implement the debug trait for QTensor.
* Add the QMatMul custom op.
Diffstat (limited to 'candle-core/src/quantized/mod.rs')
-rw-r--r-- | candle-core/src/quantized/mod.rs | 114 |
1 files changed, 113 insertions, 1 deletions
diff --git a/candle-core/src/quantized/mod.rs b/candle-core/src/quantized/mod.rs index c7e24592..842b519b 100644 --- a/candle-core/src/quantized/mod.rs +++ b/candle-core/src/quantized/mod.rs @@ -1,10 +1,15 @@ -use crate::Result; +use crate::{Device, Result, Shape, Tensor}; pub mod ggml_file; pub mod k_quants; pub use k_quants::GgmlType; +pub struct QTensor { + data: Box<dyn QuantizedType>, + shape: Shape, +} + #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum GgmlDType { F32, @@ -80,3 +85,110 @@ impl GgmlDType { } } } + +// A version of GgmlType without `vec_dot` so that it can be dyn boxed. +pub trait QuantizedType: Send + Sync { + fn dtype(&self) -> GgmlDType; + fn matmul_t(&self, mkn: (usize, usize, usize), lhs: &[f32], dst: &mut [f32]) -> Result<()>; + fn to_float(&self, ys: &mut [f32]) -> Result<()>; +} + +impl<T: k_quants::GgmlType + Send + Sync> QuantizedType for Vec<T> { + fn matmul_t(&self, mkn: (usize, usize, usize), lhs: &[f32], dst: &mut [f32]) -> Result<()> { + k_quants::matmul(mkn, lhs, self.as_slice(), dst) + } + + fn dtype(&self) -> GgmlDType { + T::DTYPE + } + + fn to_float(&self, ys: &mut [f32]) -> Result<()> { + T::to_float(self.as_slice(), ys) + } +} + +impl std::fmt::Debug for QTensor { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "QTensor[{:?}; {:?}]", self.shape, self.dtype()) + } +} + +impl QTensor { + pub fn new<S: Into<Shape>, T: k_quants::GgmlType + Send + Sync + 'static>( + data: Vec<T>, + shape: S, + ) -> Self { + Self { + data: Box::new(data), + shape: shape.into(), + } + } + + pub fn dtype(&self) -> GgmlDType { + self.data.dtype() + } + + pub fn shape(&self) -> &Shape { + &self.shape + } + + pub fn dequantize(&self, device: &Device) -> Result<Tensor> { + let mut f32_data = vec![0f32; self.shape.elem_count()]; + self.data.to_float(&mut f32_data)?; + Tensor::from_vec(f32_data, &self.shape, device) + } + + pub fn matmul_t(&self, mkn: (usize, usize, usize), lhs: &[f32], dst: &mut [f32]) -> Result<()> { + self.data.matmul_t(mkn, lhs, dst) + } +} + +#[derive(Debug, Clone)] +pub struct QMatMul(std::sync::Arc<QTensor>); + +impl QMatMul { + pub fn new(qtensor: std::sync::Arc<QTensor>) -> Self { + Self(qtensor) + } +} + +impl crate::CustomOp1 for QMatMul { + fn name(&self) -> &'static str { + "qmatmul" + } + + fn cpu_fwd( + &self, + storage: &crate::CpuStorage, + layout: &crate::Layout, + ) -> Result<(crate::CpuStorage, Shape)> { + if !layout.is_contiguous() { + crate::bail!("input tensor is not contiguous {layout:?}") + } + let src_shape = layout.shape(); + let (k, n) = self.0.shape.dims2()?; + if src_shape.rank() < 2 { + crate::bail!("input tensor has only one dimension {layout:?}") + } + let mut dst_shape = src_shape.dims().to_vec(); + let last_k = dst_shape.pop().unwrap(); + if last_k != k { + crate::bail!( + "input tensor {layout:?} incompatible with {:?}", + self.0.shape + ) + } + dst_shape.push(n); + let dst_shape = Shape::from(dst_shape); + let storage = storage.as_slice::<f32>()?; + let storage = + &storage[layout.start_offset()..layout.start_offset() + src_shape.elem_count()]; + let mut dst_storage = vec![0f32; dst_shape.elem_count()]; + self.0.matmul_t( + (dst_shape.elem_count() / n, k, n), + storage, + &mut dst_storage, + )?; + Ok((crate::CpuStorage::F32(dst_storage), dst_shape)) + } +} |