author    Laurent Mazare <laurent.mazare@gmail.com>  2024-01-07 20:21:49 +0100
committer GitHub <noreply@github.com>                2024-01-07 20:21:49 +0100
commit    0eb90ed7831d451e2e420ecd158151b44dc5b2ba (patch)
tree      19da338c2598680addd1e5f65b41d827b03a7ca9
parent    89b5a068585b73193d2004a7293d5b2fa6c30bfd (diff)
Simpler repro for the neon optimization issue + bugfix (#1544)
* Simpler repro for the neon optimization issue.
* Bugfix for q4k.
* Improve the fix, share the dot-prod bit.
* Clippy fixes.
* Fix for q6k.
* Also fix for q2k.
* Use the new shared dotprod.
* Add more testing.
-rw-r--r--  candle-core/src/quantized/neon.rs    | 208
-rw-r--r--  candle-core/tests/quantized_tests.rs |  57
2 files changed, 97 insertions(+), 168 deletions(-)
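Why the fix works: the removed k-quant code paths summed the i16 products from vmull_s8 with vaddq_s16 and then reduced them with vaddvq_s16, so the per-block accumulation stayed in i16 and could wrap for large-magnitude inputs (the overflow reported in #1526). The shared vdotq_s32 helper introduced below instead widens to i32 via vpaddlq_s16 before any horizontal reduction, so the full 16-byte sum is exact in i32. The following scalar Rust sketch is not part of the commit; it is a minimal model of the two reduction strategies (the function names are illustrative, and the exact NEON lane layout differs, but the i16 wrap does not): a 16-byte i8 dot product can wrap an i16 accumulator, while 16 * 127 * 128 = 260096 fits easily in i32.

    // Minimal sketch (assumed names), modeling the diff's before/after:
    // `dot_via_i16` wraps like the removed vaddvq_s16-based reduction,
    // `dot_via_i32` widens first, like the new vdotq_s32 helper.
    fn dot_via_i16(a: &[i8; 16], b: &[i8; 16]) -> i32 {
        // Keep the running sum in i16 and let it wrap, as the i16
        // NEON lanes do when the products get horizontally added.
        let mut sum: i16 = 0;
        for (x, y) in a.iter().zip(b.iter()) {
            sum = sum.wrapping_add((*x as i16).wrapping_mul(*y as i16));
        }
        sum as i32
    }

    fn dot_via_i32(a: &[i8; 16], b: &[i8; 16]) -> i32 {
        // Widen every product to i32 before accumulating: no wrap possible.
        a.iter().zip(b.iter()).map(|(x, y)| *x as i32 * *y as i32).sum()
    }

    fn main() {
        let (a, b) = ([127i8; 16], [127i8; 16]);
        // 16 * 127 * 127 = 258064 exceeds i16::MAX but fits in i32.
        println!("i16 path: {}", dot_via_i16(&a, &b)); // -4080, wrapped
        println!("i32 path: {}", dot_via_i32(&a, &b)); // 258064, correct
    }

The q8_0/q4_0 paths already widened through vpaddlq_s16 and were only deduplicated here; the behavioral fix is in the q2k/q3k/q4k/q5k/q6k paths, which all switch to the shared helper in the diff below.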
diff --git a/candle-core/src/quantized/neon.rs b/candle-core/src/quantized/neon.rs
index 3cb56229..c4d5d6f4 100644
--- a/candle-core/src/quantized/neon.rs
+++ b/candle-core/src/quantized/neon.rs
@@ -13,6 +13,14 @@ use core::arch::arm::*;
use core::arch::aarch64::*;
#[inline(always)]
+unsafe fn vdotq_s32(a: int8x16_t, b: int8x16_t) -> int32x4_t {
+ // TODO: dotprod
+ let p0 = vmull_s8(vget_low_s8(a), vget_low_s8(b));
+ let p1 = vmull_s8(vget_high_s8(a), vget_high_s8(b));
+ vaddq_s32(vpaddlq_s16(p0), vpaddlq_s16(p1))
+}
+
+#[inline(always)]
pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) -> Result<f32> {
let qk = QK8_0;
let nb = n / qk;
@@ -43,15 +51,8 @@ pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) ->
let v1_0l = vld1q_s8(y0.qs.as_ptr());
let v1_0h = vld1q_s8(y0.qs.as_ptr().add(16));
- // TODO: Support dotprod when it's available outside of nightly.
- let pl0l = vmull_s8(vget_low_s8(v0_0ls), vget_low_s8(v1_0l));
- let pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0l));
- let ph0l = vmull_s8(vget_low_s8(v0_0hs), vget_low_s8(v1_0h));
- let ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0h));
-
- let pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
- let ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
-
+ let pl0 = vdotq_s32(v0_0ls, v1_0l);
+ let ph0 = vdotq_s32(v0_0hs, v1_0h);
sumv0 = vmlaq_n_f32(
sumv0,
vcvtq_f32_s32(vaddq_s32(pl0, ph0)),
@@ -82,14 +83,8 @@ pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) ->
let y0_0 = vld1q_s8(y0.qs.as_ptr());
let y0_1 = vld1q_s8(y0.qs.as_ptr().add(16));
- // TODO dotprod once this is the intrinsics are.
- let p0_0 = vmull_s8(vget_low_s8(x0_0), vget_low_s8(y0_0));
- let p0_1 = vmull_s8(vget_high_s8(x0_0), vget_high_s8(y0_0));
- let p0_2 = vmull_s8(vget_low_s8(x0_1), vget_low_s8(y0_1));
- let p0_3 = vmull_s8(vget_high_s8(x0_1), vget_high_s8(y0_1));
-
- let p0 = vaddq_s32(vpaddlq_s16(p0_0), vpaddlq_s16(p0_1));
- let p1 = vaddq_s32(vpaddlq_s16(p0_2), vpaddlq_s16(p0_3));
+ let p0 = vdotq_s32(x0_0, y0_0);
+ let p1 = vdotq_s32(x0_1, y0_1);
sumv0 = vmlaq_n_f32(
sumv0,
@@ -118,10 +113,7 @@ pub(crate) fn vec_dot_q8k_q8k(n: usize, xs: &[BlockQ8K], ys: &[BlockQ8K]) -> Res
for i in (0..QK_K).step_by(16) {
let xs = vld1q_s8(xs.add(i));
let ys = vld1q_s8(ys.add(i));
- let xy_lo = vmull_s8(vget_low_s8(xs), vget_low_s8(ys));
- let xy_up = vmull_s8(vget_high_s8(xs), vget_high_s8(ys));
-
- let xy = vaddq_s32(vpaddlq_s16(xy_lo), vpaddlq_s16(xy_up));
+ let xy = vdotq_s32(xs, ys);
sum_i = vaddq_s32(sum_i, xy)
}
sumf += vaddvq_s32(sum_i) as f32 * scale
@@ -191,30 +183,16 @@ pub(crate) fn vec_dot_q6k_q8k(n: usize, xs: &[BlockQ6K], ys: &[BlockQ8K]) -> Res
let q6bytes_2 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.2, m4b), q6h_2));
let q6bytes_3 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.3, m4b), q6h_3));
- // TODO: dotprod
-
- let p0 = vaddq_s16(
- vmull_s8(vget_low_s8(q6bytes_0), vget_low_s8(q8bytes.0)),
- vmull_s8(vget_high_s8(q6bytes_0), vget_high_s8(q8bytes.0)),
- );
- let p1 = vaddq_s16(
- vmull_s8(vget_low_s8(q6bytes_1), vget_low_s8(q8bytes.1)),
- vmull_s8(vget_high_s8(q6bytes_1), vget_high_s8(q8bytes.1)),
- );
+ let p0 = vdotq_s32(q6bytes_0, q8bytes.0);
+ let p1 = vdotq_s32(q6bytes_1, q8bytes.1);
let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32);
- isum += vaddvq_s16(p0) as i32 * scale0 + vaddvq_s16(p1) as i32 * scale1;
+ isum += vaddvq_s32(p0) * scale0 + vaddvq_s32(p1) * scale1;
scale = scale.add(2);
- let p2 = vaddq_s16(
- vmull_s8(vget_low_s8(q6bytes_2), vget_low_s8(q8bytes.2)),
- vmull_s8(vget_high_s8(q6bytes_2), vget_high_s8(q8bytes.2)),
- );
- let p3 = vaddq_s16(
- vmull_s8(vget_low_s8(q6bytes_3), vget_low_s8(q8bytes.3)),
- vmull_s8(vget_high_s8(q6bytes_3), vget_high_s8(q8bytes.3)),
- );
+ let p2 = vdotq_s32(q6bytes_2, q8bytes.2);
+ let p3 = vdotq_s32(q6bytes_3, q8bytes.3);
let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32);
- isum += vaddvq_s16(p2) as i32 * scale0 + vaddvq_s16(p3) as i32 * scale1;
+ isum += vaddvq_s32(p2) * scale0 + vaddvq_s32(p3) * scale1;
scale = scale.add(2);
let q8bytes = vld1q_s8_x4(q8);
@@ -234,29 +212,16 @@ pub(crate) fn vec_dot_q6k_q8k(n: usize, xs: &[BlockQ6K], ys: &[BlockQ8K]) -> Res
let q6bytes_2 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.2, 4), q6h_2));
let q6bytes_3 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.3, 4), q6h_3));
- // TODO: dotprod case.
- let p0 = vaddq_s16(
- vmull_s8(vget_low_s8(q6bytes_0), vget_low_s8(q8bytes.0)),
- vmull_s8(vget_high_s8(q6bytes_0), vget_high_s8(q8bytes.0)),
- );
- let p1 = vaddq_s16(
- vmull_s8(vget_low_s8(q6bytes_1), vget_low_s8(q8bytes.1)),
- vmull_s8(vget_high_s8(q6bytes_1), vget_high_s8(q8bytes.1)),
- );
+ let p0 = vdotq_s32(q6bytes_0, q8bytes.0);
+ let p1 = vdotq_s32(q6bytes_1, q8bytes.1);
let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32);
- isum += vaddvq_s16(p0) as i32 * scale0 + vaddvq_s16(p1) as i32 * scale1;
+ isum += vaddvq_s32(p0) * scale0 + vaddvq_s32(p1) * scale1;
scale = scale.add(2);
- let p2 = vaddq_s16(
- vmull_s8(vget_low_s8(q6bytes_2), vget_low_s8(q8bytes.2)),
- vmull_s8(vget_high_s8(q6bytes_2), vget_high_s8(q8bytes.2)),
- );
- let p3 = vaddq_s16(
- vmull_s8(vget_low_s8(q6bytes_3), vget_low_s8(q8bytes.3)),
- vmull_s8(vget_high_s8(q6bytes_3), vget_high_s8(q8bytes.3)),
- );
+ let p2 = vdotq_s32(q6bytes_2, q8bytes.2);
+ let p3 = vdotq_s32(q6bytes_3, q8bytes.3);
let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32);
- isum += vaddvq_s16(p2) as i32 * scale0 + vaddvq_s16(p3) as i32 * scale1;
+ isum += vaddvq_s32(p2) * scale0 + vaddvq_s32(p3) * scale1;
scale = scale.add(2);
}
sum += d_all * y.d * ((isum - 32 * isum_mins) as f32);
@@ -333,28 +298,14 @@ pub(crate) fn vec_dot_q5k_q8k(n: usize, xs: &[BlockQ5K], ys: &[BlockQ8K]) -> Res
let q5bytes_2 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.0, 4), q5h_2));
let q5bytes_3 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.1, 4), q5h_3));
- // TODO: dotprod
-
- let p0 = vaddq_s16(
- vmull_s8(vget_low_s8(q5bytes_0), vget_low_s8(q8bytes.0)),
- vmull_s8(vget_high_s8(q5bytes_0), vget_high_s8(q8bytes.0)),
- );
- let p1 = vaddq_s16(
- vmull_s8(vget_low_s8(q5bytes_1), vget_low_s8(q8bytes.1)),
- vmull_s8(vget_high_s8(q5bytes_1), vget_high_s8(q8bytes.1)),
- );
- sumi += vaddvq_s16(vaddq_s16(p0, p1)) as i32 * *scales as i32;
+ let p0 = vdotq_s32(q5bytes_0, q8bytes.0);
+ let p1 = vdotq_s32(q5bytes_1, q8bytes.1);
+ sumi += vaddvq_s32(vaddq_s32(p0, p1)) * *scales as i32;
scales = scales.add(1);
- let p2 = vaddq_s16(
- vmull_s8(vget_low_s8(q5bytes_2), vget_low_s8(q8bytes.2)),
- vmull_s8(vget_high_s8(q5bytes_2), vget_high_s8(q8bytes.2)),
- );
- let p3 = vaddq_s16(
- vmull_s8(vget_low_s8(q5bytes_3), vget_low_s8(q8bytes.3)),
- vmull_s8(vget_high_s8(q5bytes_3), vget_high_s8(q8bytes.3)),
- );
- sumi += vaddvq_s16(vaddq_s16(p2, p3)) as i32 * *scales as i32;
+ let p2 = vdotq_s32(q5bytes_2, q8bytes.2);
+ let p3 = vdotq_s32(q5bytes_3, q8bytes.3);
+ sumi += vaddvq_s32(vaddq_s32(p2, p3)) * *scales as i32;
scales = scales.add(1);
}
sumf += d * sumi as f32 - dmin * sumi_mins as f32;
@@ -417,22 +368,15 @@ pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Res
for j in 0..QK_K / 64 {
let q4bits = vld1q_u8_x2(q4);
q4 = q4.add(32);
- // TODO: dotprod
let q8bytes = vld1q_s8_x2(q8);
q8 = q8.add(32);
let q4bytes = int8x16x2_t(
vreinterpretq_s8_u8(vandq_u8(q4bits.0, m4b)),
vreinterpretq_s8_u8(vandq_u8(q4bits.1, m4b)),
);
- let p0 = vaddq_s16(
- vmull_s8(vget_low_s8(q4bytes.0), vget_low_s8(q8bytes.0)),
- vmull_s8(vget_high_s8(q4bytes.0), vget_high_s8(q8bytes.0)),
- );
- let p1 = vaddq_s16(
- vmull_s8(vget_low_s8(q4bytes.1), vget_low_s8(q8bytes.1)),
- vmull_s8(vget_high_s8(q4bytes.1), vget_high_s8(q8bytes.1)),
- );
- sumi1 += vaddvq_s16(vaddq_s16(p0, p1)) as i32 * scales[2 * j] as i32;
+ let p0 = vdotq_s32(q4bytes.0, q8bytes.0);
+ let p1 = vdotq_s32(q4bytes.1, q8bytes.1);
+ sumi1 += vaddvq_s32(vaddq_s32(p0, p1)) * scales[2 * j] as i32;
let q8bytes = vld1q_s8_x2(q8);
q8 = q8.add(32);
@@ -440,15 +384,9 @@ pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Res
vreinterpretq_s8_u8(vshrq_n_u8(q4bits.0, 4)),
vreinterpretq_s8_u8(vshrq_n_u8(q4bits.1, 4)),
);
- let p2 = vaddq_s16(
- vmull_s8(vget_low_s8(q4bytes.0), vget_low_s8(q8bytes.0)),
- vmull_s8(vget_high_s8(q4bytes.0), vget_high_s8(q8bytes.0)),
- );
- let p3 = vaddq_s16(
- vmull_s8(vget_low_s8(q4bytes.1), vget_low_s8(q8bytes.1)),
- vmull_s8(vget_high_s8(q4bytes.1), vget_high_s8(q8bytes.1)),
- );
- sumi2 += vaddvq_s16(vaddq_s16(p2, p3)) as i32 * scales[2 * j + 1] as i32;
+ let p2 = vdotq_s32(q4bytes.0, q8bytes.0);
+ let p3 = vdotq_s32(q4bytes.1, q8bytes.1);
+ sumi2 += vaddvq_s32(vaddq_s32(p2, p3)) * scales[2 * j + 1] as i32;
}
sumf += d * (sumi1 + sumi2) as f32;
}
@@ -526,27 +464,14 @@ pub(crate) fn vec_dot_q3k_q8k(n: usize, xs: &[BlockQ3K], ys: &[BlockQ8K]) -> Res
vreinterpretq_s8_u8(q3h_3),
);
- // TODO: dotprod
- let p0 = vaddq_s16(
- vmull_s8(vget_low_s8(q3bytes_0), vget_low_s8(q8bytes_1.0)),
- vmull_s8(vget_high_s8(q3bytes_0), vget_high_s8(q8bytes_1.0)),
- );
- let p1 = vaddq_s16(
- vmull_s8(vget_low_s8(q3bytes_1), vget_low_s8(q8bytes_1.1)),
- vmull_s8(vget_high_s8(q3bytes_1), vget_high_s8(q8bytes_1.1)),
- );
- let p2 = vaddq_s16(
- vmull_s8(vget_low_s8(q3bytes_2), vget_low_s8(q8bytes_1.2)),
- vmull_s8(vget_high_s8(q3bytes_2), vget_high_s8(q8bytes_1.2)),
- );
- let p3 = vaddq_s16(
- vmull_s8(vget_low_s8(q3bytes_3), vget_low_s8(q8bytes_1.3)),
- vmull_s8(vget_high_s8(q3bytes_3), vget_high_s8(q8bytes_1.3)),
- );
- isum += vaddvq_s16(p0) as i32 * *scale as i32
- + vaddvq_s16(p1) as i32 * *scale.add(1) as i32
- + vaddvq_s16(p2) as i32 * *scale.add(2) as i32
- + vaddvq_s16(p3) as i32 * *scale.add(3) as i32;
+ let p0 = vdotq_s32(q3bytes_0, q8bytes_1.0);
+ let p1 = vdotq_s32(q3bytes_1, q8bytes_1.1);
+ let p2 = vdotq_s32(q3bytes_2, q8bytes_1.2);
+ let p3 = vdotq_s32(q3bytes_3, q8bytes_1.3);
+ isum += vaddvq_s32(p0) * *scale as i32
+ + vaddvq_s32(p1) * *scale.add(1) as i32
+ + vaddvq_s32(p2) * *scale.add(2) as i32
+ + vaddvq_s32(p3) * *scale.add(3) as i32;
scale = scale.add(4);
let q3h_0 = vbicq_u8(m2, qhbits.0);
@@ -571,27 +496,14 @@ pub(crate) fn vec_dot_q3k_q8k(n: usize, xs: &[BlockQ3K], ys: &[BlockQ8K]) -> Res
vreinterpretq_s8_u8(q3h_3),
);
- // TODO: dotprod
- let p0 = vaddq_s16(
- vmull_s8(vget_low_s8(q3bytes_0), vget_low_s8(q8bytes_2.0)),
- vmull_s8(vget_high_s8(q3bytes_0), vget_high_s8(q8bytes_2.0)),
- );
- let p1 = vaddq_s16(
- vmull_s8(vget_low_s8(q3bytes_1), vget_low_s8(q8bytes_2.1)),
- vmull_s8(vget_high_s8(q3bytes_1), vget_high_s8(q8bytes_2.1)),
- );
- let p2 = vaddq_s16(
- vmull_s8(vget_low_s8(q3bytes_2), vget_low_s8(q8bytes_2.2)),
- vmull_s8(vget_high_s8(q3bytes_2), vget_high_s8(q8bytes_2.2)),
- );
- let p3 = vaddq_s16(
- vmull_s8(vget_low_s8(q3bytes_3), vget_low_s8(q8bytes_2.3)),
- vmull_s8(vget_high_s8(q3bytes_3), vget_high_s8(q8bytes_2.3)),
- );
- isum += vaddvq_s16(p0) as i32 * *scale as i32
- + vaddvq_s16(p1) as i32 * *scale.add(1) as i32
- + vaddvq_s16(p2) as i32 * *scale.add(2) as i32
- + vaddvq_s16(p3) as i32 * *scale.add(3) as i32;
+ let p0 = vdotq_s32(q3bytes_0, q8bytes_2.0);
+ let p1 = vdotq_s32(q3bytes_1, q8bytes_2.1);
+ let p2 = vdotq_s32(q3bytes_2, q8bytes_2.2);
+ let p3 = vdotq_s32(q3bytes_3, q8bytes_2.3);
+ isum += vaddvq_s32(p0) * *scale as i32
+ + vaddvq_s32(p1) * *scale.add(1) as i32
+ + vaddvq_s32(p2) * *scale.add(2) as i32
+ + vaddvq_s32(p3) * *scale.add(3) as i32;
scale = scale.add(4);
if j == 0 {
@@ -649,7 +561,6 @@ pub(crate) fn vec_dot_q2k_q8k(n: usize, xs: &[BlockQ2K], ys: &[BlockQ8K]) -> Res
let mut is = 0usize;
// TODO: dotprod
-
for _j in 0..QK_K / 128 {
let q2bits = vld1q_u8_x2(q2);
q2 = q2.add(32);
@@ -696,14 +607,7 @@ unsafe fn multiply_accum_with_scale(
q2bytes: int8x16x2_t,
q8bytes: int8x16x2_t,
) -> i32 {
- let p1 = vaddq_s16(
- vmull_s8(vget_low_s8(q2bytes.0), vget_low_s8(q8bytes.0)),
- vmull_s8(vget_high_s8(q2bytes.0), vget_high_s8(q8bytes.0)),
- );
- let p2 = vaddq_s16(
- vmull_s8(vget_low_s8(q2bytes.1), vget_low_s8(q8bytes.1)),
- vmull_s8(vget_high_s8(q2bytes.1), vget_high_s8(q8bytes.1)),
- );
- vaddvq_s16(p1) as i32 * aux[is + index] as i32
- + vaddvq_s16(p2) as i32 * aux[is + 1 + index] as i32
+ let p1 = vdotq_s32(q2bytes.0, q8bytes.0);
+ let p2 = vdotq_s32(q2bytes.1, q8bytes.1);
+ vaddvq_s32(p1) * aux[is + index] as i32 + vaddvq_s32(p2) * aux[is + 1 + index] as i32
}
diff --git a/candle-core/tests/quantized_tests.rs b/candle-core/tests/quantized_tests.rs
index 716cca8d..e7a2ea7f 100644
--- a/candle-core/tests/quantized_tests.rs
+++ b/candle-core/tests/quantized_tests.rs
@@ -1,4 +1,5 @@
use candle_core::{
+ bail,
quantized::{self, GgmlDType},
test_utils::to_vec2_round,
Device, Module, Result, Tensor,
@@ -265,7 +266,8 @@ fn compare_with_error(values: &[f32], expected: &[f32], tolerance: f32) {
}
}
-/// Creates a vector simillarly to the one used in GGML unit tests: https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L26-L30
+/// Creates a vector similar to the ones used in GGML unit tests:
+/// https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L26-L30
fn create_ggml_like_vector(offset: f32) -> Vec<f32> {
(0..GGML_TEST_SIZE)
.map(|i| 0.1 + 2.0 * (i as f32 + offset).cos())
@@ -284,14 +286,15 @@ fn calculate_rmse(a: &[f32], b: &[f32]) -> f32 {
sum / a.len() as f32
}
-/// Mirrores the GGML quanitzation unit test: https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L43-L50
+/// Similar to the GGML quantization unit test:
+/// https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L43-L50
fn ggml_quantization_error_test<T: GgmlType>(max_error: f32) -> Result<()> {
let src = create_ggml_like_vector(0.0);
let mut dst = vec![0.0; GGML_TEST_SIZE];
let _quant = quantize_roundtrip::<T>(src.as_slice(), dst.as_mut_slice())?;
let error = calculate_rmse(src.as_slice(), dst.as_slice());
if error > max_error {
- candle_core::bail!(
+ bail!(
"Quantization error {} exceeds max error {}",
error,
max_error
@@ -487,54 +490,66 @@ fn ggml_reference_matmul_error(dtype: GgmlDType) -> Result<f32> {
GgmlDType::Q5K => 0.000740,
GgmlDType::Q6K => 0.000952,
GgmlDType::Q4_0 => 0.001143,
- GgmlDType::Q4_1 => 0.007784,
+ GgmlDType::Q4_1 => 0.008,
GgmlDType::Q5_0 => 0.001353,
- GgmlDType::Q5_1 => 0.001363,
+ GgmlDType::Q5_1 => 0.00149,
GgmlDType::Q8_0 => 0.000092,
// Not from the ggml repo.
GgmlDType::Q8K => 0.00065,
- _ => candle_core::bail!("No GGML results for quantization type {dtype:?}",),
+ _ => bail!("No GGML results for quantization type {dtype:?}",),
};
Ok(err)
}
-/// Mirrores the GGML matmul unit test: https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L76-L91
+/// Similar to the GGML matmul unit test:
+/// https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L76-L91
fn ggml_matmul_error_test<T: GgmlType>() -> Result<()> {
let a = create_ggml_like_vector(0.0);
let b = create_ggml_like_vector(1.0);
+ ggml_matmul_error_test_::<T>(a.as_slice(), b.as_slice(), 1.0)?;
+ // Another example that is more likely to trigger the overflow reported in #1526
+ let a = (0..GGML_TEST_SIZE)
+ .map(|i| i as f32 / GGML_TEST_SIZE as f32)
+ .collect::<Vec<_>>();
+ let b = (0..GGML_TEST_SIZE)
+ .map(|i| i as f32 / GGML_TEST_SIZE as f32)
+ .collect::<Vec<_>>();
+ ggml_matmul_error_test_::<T>(a.as_slice(), b.as_slice(), 2.0)?;
+ Ok(())
+}
+
+fn ggml_matmul_error_test_<T: GgmlType>(a: &[f32], b: &[f32], err_m: f32) -> Result<()> {
let length = a.len();
let mut a_quant = vec![T::zeros(); length / T::BLCK_SIZE];
let mut b_quant = vec![T::VecDotType::zeros(); length / T::VecDotType::BLCK_SIZE];
- T::from_float(&a, &mut a_quant)?;
- T::VecDotType::from_float(&b, &mut b_quant)?;
+ T::from_float(a, &mut a_quant)?;
+ T::VecDotType::from_float(b, &mut b_quant)?;
let result = T::vec_dot(length, &a_quant, &b_quant)?;
let result_unopt = T::vec_dot_unopt(length, &a_quant, &b_quant)?;
- let reference_result = vec_dot_reference(&a, &b);
+ let reference_result = vec_dot_reference(a, b);
if (result - result_unopt).abs() / length as f32 > 1e-6 {
- candle_core::bail!(
+ bail!(
"the opt and unopt vec-dot returned different values, opt {result}, unopt {result_unopt}"
)
}
let error = (result - reference_result).abs() / length as f32;
- let ggml_error = ggml_reference_matmul_error(T::DTYPE)?;
+ let ggml_error = ggml_reference_matmul_error(T::DTYPE)? * err_m;
if !error.is_finite() || error > GGML_MAX_DOT_PRODUCT_ERROR {
- candle_core::bail!(
- "Dot product error {error} exceeds max error {GGML_MAX_DOT_PRODUCT_ERROR}",
- );
+ bail!("Dot product error {error} exceeds max error {GGML_MAX_DOT_PRODUCT_ERROR}",);
}
// We diverge slightly due to different rounding behavior / f16 to f32 conversions in GGML
// => we use a slightly higher error threshold
const ERROR_LENIENCY: f32 = 0.00001;
if error - ERROR_LENIENCY > ggml_error {
- candle_core::bail!(
+ bail!(
"Dot product error {} exceeds ggml reference error {}",
error,
ggml_error
@@ -543,6 +558,16 @@ fn ggml_matmul_error_test<T: GgmlType>() -> Result<()> {
Ok(())
}
+#[test]
+fn quantized_mm() -> Result<()> {
+ ggml_matmul_error_test::<k_quants::BlockQ4_0>()?;
+ ggml_matmul_error_test::<k_quants::BlockQ4_1>()?;
+ ggml_matmul_error_test::<k_quants::BlockQ5_0>()?;
+ ggml_matmul_error_test::<k_quants::BlockQ5_1>()?;
+ ggml_matmul_error_test::<k_quants::BlockQ8_0>()?;
+ Ok(())
+}
+
/// generates random tensors of size `m x k` and `n x k` and calculates their expected matrix multiplication result.
fn get_random_tensors(
m: usize,