summaryrefslogtreecommitdiff
path: root/candle-core/src/quantized/mod.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-core/src/quantized/mod.rs')
-rw-r--r--candle-core/src/quantized/mod.rs114
1 file changed, 113 insertions, 1 deletion
diff --git a/candle-core/src/quantized/mod.rs b/candle-core/src/quantized/mod.rs
index c7e24592..842b519b 100644
--- a/candle-core/src/quantized/mod.rs
+++ b/candle-core/src/quantized/mod.rs
@@ -1,10 +1,15 @@
-use crate::Result;
+use crate::{Device, Result, Shape, Tensor};
pub mod ggml_file;
pub mod k_quants;
pub use k_quants::GgmlType;
/// A quantized tensor: type-erased quantized storage paired with the logical
/// shape the data represents.
pub struct QTensor {
    // Boxed so tensors with different concrete quantization block types can be
    // handled uniformly; the concrete type is fixed at construction time.
    data: Box<dyn QuantizedType>,
    // Logical (dequantized) dimensions of the tensor.
    shape: Shape,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GgmlDType {
F32,
@@ -80,3 +85,110 @@ impl GgmlDType {
}
}
}
+
// A version of GgmlType without `vec_dot` so that it can be dyn boxed.
pub trait QuantizedType: Send + Sync {
    /// The GGML dtype of the underlying quantized elements.
    fn dtype(&self) -> GgmlDType;
    /// Matrix multiplication against this quantized buffer: `mkn` carries the
    /// (m, k, n) problem sizes, `lhs` is the f32 left operand, and the f32
    /// result is written into `dst`. The `_t` suffix suggests the quantized
    /// operand is used transposed — TODO confirm against `k_quants::matmul`.
    fn matmul_t(&self, mkn: (usize, usize, usize), lhs: &[f32], dst: &mut [f32]) -> Result<()>;
    /// Dequantize every element into `ys` as f32.
    fn to_float(&self, ys: &mut [f32]) -> Result<()>;
}
+
+impl<T: k_quants::GgmlType + Send + Sync> QuantizedType for Vec<T> {
+ fn matmul_t(&self, mkn: (usize, usize, usize), lhs: &[f32], dst: &mut [f32]) -> Result<()> {
+ k_quants::matmul(mkn, lhs, self.as_slice(), dst)
+ }
+
+ fn dtype(&self) -> GgmlDType {
+ T::DTYPE
+ }
+
+ fn to_float(&self, ys: &mut [f32]) -> Result<()> {
+ T::to_float(self.as_slice(), ys)
+ }
+}
+
+impl std::fmt::Debug for QTensor {
+ fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+ write!(f, "QTensor[{:?}; {:?}]", self.shape, self.dtype())
+ }
+}
+
+impl QTensor {
+ pub fn new<S: Into<Shape>, T: k_quants::GgmlType + Send + Sync + 'static>(
+ data: Vec<T>,
+ shape: S,
+ ) -> Self {
+ Self {
+ data: Box::new(data),
+ shape: shape.into(),
+ }
+ }
+
+ pub fn dtype(&self) -> GgmlDType {
+ self.data.dtype()
+ }
+
+ pub fn shape(&self) -> &Shape {
+ &self.shape
+ }
+
+ pub fn dequantize(&self, device: &Device) -> Result<Tensor> {
+ let mut f32_data = vec![0f32; self.shape.elem_count()];
+ self.data.to_float(&mut f32_data)?;
+ Tensor::from_vec(f32_data, &self.shape, device)
+ }
+
+ pub fn matmul_t(&self, mkn: (usize, usize, usize), lhs: &[f32], dst: &mut [f32]) -> Result<()> {
+ self.data.matmul_t(mkn, lhs, dst)
+ }
+}
+
// Newtype wrapping a shared QTensor so it can act as a custom matmul op;
// `Arc` keeps clones of the op cheap and lets the weight be shared.
#[derive(Debug, Clone)]
pub struct QMatMul(std::sync::Arc<QTensor>);
+
+impl QMatMul {
+ pub fn new(qtensor: std::sync::Arc<QTensor>) -> Self {
+ Self(qtensor)
+ }
+}
+
impl crate::CustomOp1 for QMatMul {
    fn name(&self) -> &'static str {
        "qmatmul"
    }

    // CPU forward pass: multiplies the f32 input against the wrapped 2D
    // quantized tensor. The input's trailing dimension must equal the weight's
    // first dimension (k); it is replaced by the weight's second dimension (n)
    // in the output shape.
    fn cpu_fwd(
        &self,
        storage: &crate::CpuStorage,
        layout: &crate::Layout,
    ) -> Result<(crate::CpuStorage, Shape)> {
        // The matmul below reads the input as one flat f32 slice, so a
        // strided/non-contiguous layout cannot be handled here.
        if !layout.is_contiguous() {
            crate::bail!("input tensor is not contiguous {layout:?}")
        }
        let src_shape = layout.shape();
        // The wrapped weight must be a 2D matrix; `dims2` errors otherwise.
        let (k, n) = self.0.shape.dims2()?;
        if src_shape.rank() < 2 {
            crate::bail!("input tensor has only one dimension {layout:?}")
        }
        // Output shape: same leading dims as the input, last dim k -> n.
        let mut dst_shape = src_shape.dims().to_vec();
        let last_k = dst_shape.pop().unwrap();
        if last_k != k {
            crate::bail!(
                "input tensor {layout:?} incompatible with {:?}",
                self.0.shape
            )
        }
        dst_shape.push(n);
        let dst_shape = Shape::from(dst_shape);
        let storage = storage.as_slice::<f32>()?;
        // Restrict to this tensor's window inside the backing buffer.
        let storage =
            &storage[layout.start_offset()..layout.start_offset() + src_shape.elem_count()];
        let mut dst_storage = vec![0f32; dst_shape.elem_count()];
        // m = product of all leading output dims (rows after flattening).
        self.0.matmul_t(
            (dst_shape.elem_count() / n, k, n),
            storage,
            &mut dst_storage,
        )?;
        Ok((crate::CpuStorage::F32(dst_storage), dst_shape))
    }
}