summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r-- candle-core/src/quantized/avx.rs 5
-rw-r--r-- candle-core/src/quantized/k_quants.rs 5
-rw-r--r-- candle-core/src/quantized/neon.rs 29
-rw-r--r-- candle-core/src/quantized/simd128.rs 4
4 files changed, 2 insertions, 41 deletions
diff --git a/candle-core/src/quantized/avx.rs b/candle-core/src/quantized/avx.rs
index d4b05bb0..5c3ac822 100644
--- a/candle-core/src/quantized/avx.rs
+++ b/candle-core/src/quantized/avx.rs
@@ -50,14 +50,9 @@ pub(crate) unsafe fn mul_sum_i8_pairs_float(x: __m256i, y: __m256i) -> __m256 {
#[inline(always)]
pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) -> Result<f32> {
let qk = QK8_0;
- let nb = n / qk;
if n % QK8_0 != 0 {
crate::bail!("vec_dot_q4_0_q8_0: {n} is not divisible by {qk}")
}
- if nb % 2 != 0 {
- crate::bail!("vec_dot_q4_0_q8_0: {nb} is not even")
- }
-
unsafe {
let mut acc = _mm256_setzero_ps();
for (x, y) in xs.iter().zip(ys.iter()) {
diff --git a/candle-core/src/quantized/k_quants.rs b/candle-core/src/quantized/k_quants.rs
index b140131e..d16289e6 100644
--- a/candle-core/src/quantized/k_quants.rs
+++ b/candle-core/src/quantized/k_quants.rs
@@ -236,14 +236,9 @@ impl GgmlType for BlockQ4_0 {
fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
let qk = QK8_0;
- let nb = n / qk;
if n % QK8_0 != 0 {
crate::bail!("vec_dot_q4_0_q8_0: {n} is not divisible by {qk}")
}
- if nb % 2 != 0 {
- crate::bail!("vec_dot_q4_0_q8_0: {nb} is not even")
- }
-
// Generic implementation.
let mut sumf = 0f32;
for (xs, ys) in xs.iter().zip(ys.iter()) {
diff --git a/candle-core/src/quantized/neon.rs b/candle-core/src/quantized/neon.rs
index 51bd3e66..3cb56229 100644
--- a/candle-core/src/quantized/neon.rs
+++ b/candle-core/src/quantized/neon.rs
@@ -19,42 +19,29 @@ pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) ->
if n % QK8_0 != 0 {
crate::bail!("vec_dot_q4_0_q8_0: {n} is not divisible by {qk}")
}
- if nb % 2 != 0 {
- crate::bail!("vec_dot_q4_0_q8_0: {nb} is not even")
- }
unsafe {
let mut sumv0 = vdupq_n_f32(0.0f32);
- let mut sumv1 = vdupq_n_f32(0.0f32);
- for i in (0..nb).step_by(2) {
+ for i in 0..nb {
let x0 = &xs[i];
- let x1 = &xs[i + 1];
let y0 = &ys[i];
- let y1 = &ys[i + 1];
let m4b = vdupq_n_u8(0x0F);
let s8b = vdupq_n_s8(0x8);
let v0_0 = vld1q_u8(x0.qs.as_ptr());
- let v0_1 = vld1q_u8(x1.qs.as_ptr());
// 4-bit -> 8-bit
let v0_0l = vreinterpretq_s8_u8(vandq_u8(v0_0, m4b));
let v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
- let v0_1l = vreinterpretq_s8_u8(vandq_u8(v0_1, m4b));
- let v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
// sub 8
let v0_0ls = vsubq_s8(v0_0l, s8b);
let v0_0hs = vsubq_s8(v0_0h, s8b);
- let v0_1ls = vsubq_s8(v0_1l, s8b);
- let v0_1hs = vsubq_s8(v0_1h, s8b);
// load y
let v1_0l = vld1q_s8(y0.qs.as_ptr());
let v1_0h = vld1q_s8(y0.qs.as_ptr().add(16));
- let v1_1l = vld1q_s8(y1.qs.as_ptr());
- let v1_1h = vld1q_s8(y1.qs.as_ptr().add(16));
// TODO: Support dotprod when it's available outside of nightly.
let pl0l = vmull_s8(vget_low_s8(v0_0ls), vget_low_s8(v1_0l));
@@ -62,28 +49,16 @@ pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) ->
let ph0l = vmull_s8(vget_low_s8(v0_0hs), vget_low_s8(v1_0h));
let ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0h));
- let pl1l = vmull_s8(vget_low_s8(v0_1ls), vget_low_s8(v1_1l));
- let pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1l));
- let ph1l = vmull_s8(vget_low_s8(v0_1hs), vget_low_s8(v1_1h));
- let ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1h));
-
let pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
let ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
- let pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
- let ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
sumv0 = vmlaq_n_f32(
sumv0,
vcvtq_f32_s32(vaddq_s32(pl0, ph0)),
x0.d.to_f32() * y0.d.to_f32(),
);
- sumv1 = vmlaq_n_f32(
- sumv1,
- vcvtq_f32_s32(vaddq_s32(pl1, ph1)),
- x1.d.to_f32() * y1.d.to_f32(),
- );
}
- Ok(vaddvq_f32(sumv0) + vaddvq_f32(sumv1))
+ Ok(vaddvq_f32(sumv0))
}
}
diff --git a/candle-core/src/quantized/simd128.rs b/candle-core/src/quantized/simd128.rs
index f256fdc2..1c8c0f20 100644
--- a/candle-core/src/quantized/simd128.rs
+++ b/candle-core/src/quantized/simd128.rs
@@ -11,10 +11,6 @@ pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) ->
if n % QK8_0 != 0 {
crate::bail!("vec_dot_q4_0_q8_0: {n} is not divisible by {qk}")
}
- let nb = n / QK8_0;
- if nb % 2 != 0 {
- crate::bail!("vec_dot_q4_0_q8_0: {nb} is not even")
- }
unsafe {
let mut acc = f32x4_splat(0.0f32);
for (x, y) in xs.iter().zip(ys.iter()) {