summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-core/src/ggml.rs159
-rw-r--r--candle-core/tests/ggml_tests.rs33
2 files changed, 119 insertions, 73 deletions
diff --git a/candle-core/src/ggml.rs b/candle-core/src/ggml.rs
index 0b3dee04..3a41eeec 100644
--- a/candle-core/src/ggml.rs
+++ b/candle-core/src/ggml.rs
@@ -15,18 +15,24 @@ pub const QK5_1: usize = 32;
pub const QK8_0: usize = 32;
pub const QK8_1: usize = 32;
-pub trait GgmlType: Sized {
+pub trait GgmlType: Sized + Clone {
const DTYPE: GgmlDType;
const BLCK_SIZE: usize;
+ type VecDotType: GgmlType;
+ // This is only safe for types that include immediate values such as float/int/...
+ fn zeros() -> Self {
+ unsafe { std::mem::MaybeUninit::zeroed().assume_init() }
+ }
fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()>;
fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()>;
- type VecDotType: GgmlType;
- // Dot product used as a building block for quantized mat-mul.
+ /// Dot product used as a building block for quantized mat-mul.
+ /// n is the number of elements to be considered.
fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32>;
}
+#[derive(Debug, Clone, PartialEq)]
#[repr(C)]
pub struct BlockQ4_0 {
d: f16,
@@ -34,6 +40,7 @@ pub struct BlockQ4_0 {
}
const _: () = assert!(std::mem::size_of::<BlockQ4_0>() == 18);
+#[derive(Debug, Clone, PartialEq)]
#[repr(C)]
pub struct BlockQ4_1 {
d: f16,
@@ -42,6 +49,7 @@ pub struct BlockQ4_1 {
}
const _: () = assert!(std::mem::size_of::<BlockQ4_1>() == 20);
+#[derive(Debug, Clone, PartialEq)]
#[repr(C)]
pub struct BlockQ5_0 {
d: f16,
@@ -50,6 +58,7 @@ pub struct BlockQ5_0 {
}
const _: () = assert!(std::mem::size_of::<BlockQ5_0>() == 22);
+#[derive(Debug, Clone, PartialEq)]
#[repr(C)]
pub struct BlockQ5_1 {
d: f16,
@@ -59,6 +68,7 @@ pub struct BlockQ5_1 {
}
const _: () = assert!(std::mem::size_of::<BlockQ5_1>() == 24);
+#[derive(Debug, Clone, PartialEq)]
#[repr(C)]
pub struct BlockQ8_0 {
d: f16,
@@ -66,6 +76,7 @@ pub struct BlockQ8_0 {
}
const _: () = assert!(std::mem::size_of::<BlockQ8_0>() == 34);
+#[derive(Debug, Clone, PartialEq)]
#[repr(C)]
pub struct BlockQ8_1 {
d: f16,
@@ -74,6 +85,7 @@ pub struct BlockQ8_1 {
}
const _: () = assert!(std::mem::size_of::<BlockQ8_1>() == 36);
+#[derive(Debug, Clone, PartialEq)]
#[repr(C)]
pub struct BlockQ2K {
scales: [u8; QK_K / 16],
@@ -83,6 +95,7 @@ pub struct BlockQ2K {
}
const _: () = assert!(QK_K / 16 + QK_K / 4 + 2 * 2 == std::mem::size_of::<BlockQ2K>());
+#[derive(Debug, Clone, PartialEq)]
#[repr(C)]
pub struct BlockQ3K {
hmask: [u8; QK_K / 8],
@@ -92,6 +105,7 @@ pub struct BlockQ3K {
}
const _: () = assert!(QK_K / 8 + QK_K / 4 + 12 + 2 == std::mem::size_of::<BlockQ3K>());
+#[derive(Debug, Clone, PartialEq)]
// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/k_quants.h#L82
#[repr(C)]
pub struct BlockQ4K {
@@ -102,6 +116,7 @@ pub struct BlockQ4K {
}
const _: () = assert!(QK_K / 2 + K_SCALE_SIZE + 2 * 2 == std::mem::size_of::<BlockQ4K>());
+#[derive(Debug, Clone, PartialEq)]
#[repr(C)]
pub struct BlockQ5K {
d: f16,
@@ -113,6 +128,7 @@ pub struct BlockQ5K {
const _: () =
assert!(QK_K / 8 + QK_K / 2 + 2 * 2 + K_SCALE_SIZE == std::mem::size_of::<BlockQ5K>());
+#[derive(Debug, Clone, PartialEq)]
#[repr(C)]
pub struct BlockQ6K {
ql: [u8; QK_K / 2],
@@ -122,6 +138,7 @@ pub struct BlockQ6K {
}
const _: () = assert!(3 * QK_K / 4 + QK_K / 16 + 2 == std::mem::size_of::<BlockQ6K>());
+#[derive(Debug, Clone, PartialEq)]
#[repr(C)]
pub struct BlockQ8K {
d: f32,
@@ -535,8 +552,41 @@ impl GgmlType for BlockQ4_0 {
Ok(())
}
- fn from_float(_: &[f32], _: &mut [Self]) -> Result<()> {
- todo!()
+ fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> {
+ // quantize_row_q4_0
+ let qk = Self::BLCK_SIZE;
+ let k = xs.len();
+ if k % qk != 0 {
+ crate::bail!("{k} is not divisible by {}", qk);
+ };
+ let nb = k / qk;
+ if ys.len() != nb {
+ crate::bail!("size mismatch {} {} {}", xs.len(), ys.len(), qk,)
+ }
+ for (i, ys) in ys.iter_mut().enumerate() {
+ let mut amax = 0f32;
+ let mut max = 0f32;
+
+ let xs = &xs[i * qk..(i + 1) * qk];
+ for &x in xs.iter() {
+ if amax < x.abs() {
+ amax = x.abs();
+ max = x;
+ }
+ }
+ let d = max / -8.0;
+ let id = if d != 0f32 { 1. / d } else { 0. };
+ ys.d = f16::from_f32(d);
+
+ for (j, q) in ys.qs.iter_mut().enumerate() {
+ let x0 = xs[j] * id;
+ let x1 = xs[qk / 2 + j] * id;
+ let xi0 = u8::min(15, (x0 + 8.5) as u8);
+ let xi1 = u8::min(15, (x1 + 8.5) as u8);
+ *q = xi0 | (xi1 << 4)
+ }
+ }
+ Ok(())
}
// https://github.com/ggerganov/llama.cpp/blob/b5ffb2849d23afe73647f68eec7b68187af09be6/ggml.c#L2361C10-L2361C122
@@ -555,9 +605,9 @@ impl GgmlType for BlockQ4_0 {
for i in 0..nb {
let mut sum_i = 0;
for j in 0..qk / 2 {
- let v0 = (xs[i].qs[j] & 0x0F) - 8;
- let v1 = (xs[i].qs[j] >> 4) - 8;
- sum_i += v0 * ys[i].qs[j] + v1 * ys[i].qs[j + qk / 2]
+ let v0 = (xs[i].qs[j] & 0x0F) as i32 - 8;
+ let v1 = (xs[i].qs[j] >> 4) as i32 - 8;
+ sum_i += v0 * ys[i].qs[j] as i32 + v1 * ys[i].qs[j + qk / 2] as i32
}
sumf += sum_i as f32 * f16::to_f32(xs[i].d) * f16::to_f32(ys[i].d)
}
@@ -591,7 +641,7 @@ impl GgmlType for BlockQ8_0 {
fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> {
// quantize_row_q8_0
- let k = ys.len();
+ let k = xs.len();
if k % Self::BLCK_SIZE != 0 {
crate::bail!("{k} is not divisible by {}", Self::BLCK_SIZE);
};
@@ -608,7 +658,7 @@ impl GgmlType for BlockQ8_0 {
let mut amax = 0f32;
let xs = &xs[i * Self::BLCK_SIZE..(i + 1) * Self::BLCK_SIZE];
for &x in xs.iter() {
- amax = amax.max(x)
+ amax = amax.max(x.abs())
}
let d = amax / ((1 << 7) - 1) as f32;
let id = if d != 0f32 { 1. / d } else { 0. };
@@ -644,73 +694,36 @@ impl GgmlType for BlockQ8_1 {
}
}
-const BLCK0: usize = 16;
-const BLCK1: usize = 16;
-
-// This implementation is in-line with the ggml one and keeps the same variable names.
// https://github.com/ggerganov/llama.cpp/blob/b5ffb2849d23afe73647f68eec7b68187af09be6/ggml.c#L10605
-pub fn forward_mul_mat<T: GgmlType>(src0: &[T], src1: &[f32], dst: &mut [f32]) -> Result<()> {
- // TODO: Use the proper sizes here.
- let (ne00, ne01, ne02, ne03) = (1, 1, 1, 1);
- let (ne10, ne11, ne12, ne13) = (1, 1, 1, 1);
- // The strides are in bytes in ggml, however we use the number of elements in candle.
- let (_, nb1, nb2, nb3) = (1, 1, 1, 1);
- let (_, nb01, nb02, nb03) = (1, 1, 1, 1);
- let (_, nb11, nb12, nb13) = (1, 1, 1, 1);
-
- let nr0 = ne01; // src0 rows
- let nr1 = ne11 * ne12 * ne13;
-
- // TODO: Either add multi-threading or remove these bits.
- let ir010 = 0;
- let ir011 = nr0;
- let ir110 = 0;
- let ir111 = nr1;
- let r2 = ne12 / ne02;
- let r3 = ne13 / ne03;
+pub fn matmul<T: GgmlType>(
+ mkn: (usize, usize, usize),
+ lhs: &[f32],
+ rhs_t: &[T],
+ dst: &mut [f32],
+) -> Result<()> {
+ let (m, k, n) = mkn;
+ if m * k != lhs.len() {
+ crate::bail!("unexpected lhs length {} {mkn:?}", lhs.len());
+ }
+ let k_in_lhs_blocks = (k + T::BLCK_SIZE - 1) / T::BLCK_SIZE;
+ let k_in_rhs_blocks = (k + T::VecDotType::BLCK_SIZE - 1) / T::VecDotType::BLCK_SIZE;
+ // TODO: Do not make this copy if the DotType is f32.
// TODO: Pre-allocate this.
- let wdata = &mut [];
- if ne10 % T::BLCK_SIZE != 0 {
- crate::bail!(
- "forward_mul_mat: ne10 {ne10} is not divisible by block size {}",
- T::BLCK_SIZE
- )
+ let mut lhs_b = vec![T::VecDotType::zeros(); m * k_in_lhs_blocks];
+ for row_idx in 0..m {
+ let lhs_b = &mut lhs_b[row_idx * k_in_lhs_blocks..(row_idx + 1) * k_in_lhs_blocks];
+ let lhs = &lhs[row_idx * k..(row_idx + 1) * k];
+ T::VecDotType::from_float(lhs, lhs_b)?
}
- let row_size = ne10 / T::BLCK_SIZE;
- for i13 in 0..ne13 {
- for i12 in 0..ne12 {
- for i11 in 0..ne11 {
- let wdata_idx = i11 + i12 * ne11 + i13 * ne11 * ne12;
- let wdata = &mut wdata[wdata_idx..wdata_idx + row_size];
- let src1 = &src1[i13 * nb13 + i12 * nb12 + i11 * nb11..];
- T::VecDotType::from_float(src1, wdata)?
- }
- }
- }
- for iir1 in (ir110..ir111).step_by(BLCK1) {
- for iir0 in (ir010..ir011).step_by(BLCK0) {
- for ir1 in iir1..usize::min(iir1 + BLCK1, ir111) {
- let i13 = ir1 / (ne12 * ne11);
- let i12 = (ir1 - i13 * ne12 * ne11) / ne11;
- let i11 = ir1 - i13 * ne12 * ne11 - i12 * ne11;
-
- let i03 = i13 / r3;
- let i02 = i12 / r2;
-
- let i1 = i11;
- let i2 = i12;
- let i3 = i13;
-
- let src0_row = &src0[i02 * nb02 + i03 * nb03..];
- let src1_col = &wdata[(i11 + i12 * ne11 + i13 * ne12 * ne11) * row_size..];
- let dst_col = &mut dst[i1 * nb1 + i2 * nb2 + i3 * nb3..];
- for ir0 in iir0..usize::min(iir0 + BLCK0, ir011) {
- let src0_row = &src0_row[ir0 * nb01..];
- let v = T::vec_dot(ne00, src0_row, src1_col)?;
- dst_col[ir0 - iir0] += v
- }
- }
+ let lhs_b = lhs_b.as_slice();
+
+ for row_idx in 0..m {
+ let lhs_row = &lhs_b[row_idx * k_in_lhs_blocks..(row_idx + 1) * k_in_lhs_blocks];
+ let dst_row = &mut dst[row_idx * n..(row_idx + 1) * n];
+ for (col_idx, dst) in dst_row.iter_mut().enumerate() {
+ let rhs_col = &rhs_t[col_idx * k_in_rhs_blocks..(col_idx + 1) * k_in_rhs_blocks];
+ *dst = T::vec_dot(k, rhs_col, lhs_row)?;
}
}
Ok(())
diff --git a/candle-core/tests/ggml_tests.rs b/candle-core/tests/ggml_tests.rs
new file mode 100644
index 00000000..d976ad99
--- /dev/null
+++ b/candle-core/tests/ggml_tests.rs
@@ -0,0 +1,33 @@
+use candle_core::{ggml, Device, Result, Tensor};
+use ggml::GgmlType;
+
+#[test]
+fn ggml_matmul() -> Result<()> {
+ let cpu = &Device::Cpu;
+ let (m, k, n) = (3, 64, 4);
+ let lhs = (0..(m * k)).map(|v| v as f32).collect::<Vec<_>>();
+ let tensor_lhs = Tensor::from_slice(&lhs, (m, k), cpu)?;
+ let mut dst = vec![42.; 3 * 4];
+ let mut rhs_t = vec![ggml::BlockQ4_0::zeros(); 8];
+ let rhs = (0..(k * n)).map(|v| v as f32).collect::<Vec<_>>();
+ let tensor_rhs = Tensor::from_slice(&rhs, (n, k), cpu)?.t()?;
+ ggml::BlockQ4_0::from_float(&rhs, &mut rhs_t)?;
+ ggml::matmul((m, k, n), &lhs, &rhs_t, &mut dst)?;
+ assert_eq!(
+ dst,
+ &[
+ 85120.43, 214561.61, 345454.9, 474748.1, 213474.94, 604465.25, 1000686.4, 1388317.3,
+ 341875.88, 994283.0, 1655708.8, 2301518.3
+ ]
+ );
+ let mm = tensor_lhs.matmul(&tensor_rhs)?;
+ assert_eq!(
+ mm.to_vec2::<f32>()?,
+ &[
+ [85344.0, 214368.0, 343392.0, 472416.0],
+ [214368.0, 605536.0, 996704.0, 1387872.0],
+ [343392.0, 996704.0, 1650016.0, 2303328.0]
+ ]
+ );
+ Ok(())
+}