-rw-r--r--  candle-core/src/quantized/k_quants.rs   4
-rw-r--r--  candle-core/src/quantized/neon.rs       96
2 files changed, 99 insertions(+), 1 deletion(-)
diff --git a/candle-core/src/quantized/k_quants.rs b/candle-core/src/quantized/k_quants.rs
index fec240bb..7b405ec9 100644
--- a/candle-core/src/quantized/k_quants.rs
+++ b/candle-core/src/quantized/k_quants.rs
@@ -1102,7 +1102,11 @@ impl GgmlType for BlockQ4K {
     const BLCK_SIZE: usize = QK_K;
     type VecDotType = BlockQ8K;
 
+    #[allow(unreachable_code)]
     fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
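+        // With NEON available at compile time, dispatch to the SIMD kernel and
+        // skip the scalar fallback below (hence allow(unreachable_code) above).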
+        #[cfg(target_feature = "neon")]
+        return super::neon::vec_dot_q4k_q8k(n, xs, ys);
+
         if n % QK_K != 0 {
             crate::bail!("vec_dot_q4k_q8k: {n} is not divisible by {QK_K}")
         }
diff --git a/candle-core/src/quantized/neon.rs b/candle-core/src/quantized/neon.rs
index 32c93af4..69d616f4 100644
--- a/candle-core/src/quantized/neon.rs
+++ b/candle-core/src/quantized/neon.rs
@@ -1,5 +1,6 @@
-use super::k_quants::{BlockQ4_0, BlockQ6K, BlockQ8K, BlockQ8_0, QK8_0, QK_K};
+use super::k_quants::{BlockQ4K, BlockQ4_0, BlockQ6K, BlockQ8K, BlockQ8_0, QK8_0, QK_K};
 use crate::Result;
+use byteorder::{ByteOrder, LittleEndian};
 
 #[allow(unused_imports)]
 #[cfg(target_arch = "arm")]
@@ -279,3 +280,96 @@ pub(crate) fn vec_dot_q6k_q8k(n: usize, xs: &[BlockQ6K], ys: &[BlockQ8K]) -> Result<f32> {
     }
     Ok(sum)
 }
+
+#[inline(always)]
+pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Result<f32> {
+    if n % QK_K != 0 {
+        crate::bail!("vec_dot_q4k_q8k: {n} is not divisible by {QK_K}")
+    }
+    let mut sumf = 0f32;
+    let mut utmp = [0u32; 4];
+    let mut scales = [0u8; 16];
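+    // Bit masks used to unpack the 6-bit scales and mins packed in x.scales.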
+    let kmask1: u32 = 0x3f3f3f3f;
+    let kmask2: u32 = 0x0f0f0f0f;
+    let kmask3: u32 = 0x03030303;
+
+    unsafe {
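+        // Each byte of x.qs packs two 4-bit quants; this mask selects the low nibble.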
+        let m4b = vdupq_n_u8(0xF);
+
+        for (x, y) in xs.iter().zip(ys.iter()) {
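+            // Super-block scale and min (stored as f16), combined with the q8 scale.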
+            let d = y.d * x.d.to_f32();
+            let dmin = y.d * x.dmin.to_f32();
+
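+            // Pairwise-add the 16 per-sub-block sums of the q8 block down to 8 lanes.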
+            let q8sums = vpaddq_s16(
+                vld1q_s16(y.bsums.as_ptr()),
+                vld1q_s16(y.bsums.as_ptr().add(8)),
+            );
+
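+            // x.scales packs eight 6-bit scales and eight 6-bit mins into 12 bytes;
+            // read them as three u32 words, then shift/mask the pieces apart.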
+            LittleEndian::read_u32_into(&x.scales, &mut utmp[0..3]);
+
+            let mins8 = vld1_u32(
+                [
+                    utmp[1] & kmask1,
+                    ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4),
+                ]
+                .as_ptr(),
+            );
+            utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+            utmp[0] &= kmask1;
+
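+            // Widen the eight mins to i16, multiply with the sub-block sums, and
+            // subtract this constant "min" contribution from the running total.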
+            let mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins8)));
+            let prod = vaddq_s32(
+                vmull_s16(vget_low_s16(q8sums), vget_low_s16(mins)),
+                vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins)),
+            );
+            sumf -= dmin * vaddvq_s32(prod) as f32;
+
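+            // Store the unpacked 6-bit scales back as bytes for easy indexing below.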
+            LittleEndian::write_u32_into(&utmp, &mut scales);
+
+            let mut q4 = x.qs.as_ptr();
+            let mut q8 = y.qs.as_ptr();
+
+            let mut sumi1 = 0i32;
+            let mut sumi2 = 0i32;
+
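+            // Each iteration covers 64 quants: 32 q4 bytes hold 64 nibbles; the low
+            // nibbles use scales[2 * j] and the high nibbles use scales[2 * j + 1].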
+            for j in 0..QK_K / 64 {
+                let q4bits = vld1q_u8_x2(q4);
+                q4 = q4.add(32);
+                // TODO: dotprod
+                let q8bytes = vld1q_s8_x2(q8);
+                q8 = q8.add(32);
+                let q4bytes = int8x16x2_t(
+                    vreinterpretq_s8_u8(vandq_u8(q4bits.0, m4b)),
+                    vreinterpretq_s8_u8(vandq_u8(q4bits.1, m4b)),
+                );
+                let p0 = vaddq_s16(
+                    vmull_s8(vget_low_s8(q4bytes.0), vget_low_s8(q8bytes.0)),
+                    vmull_s8(vget_high_s8(q4bytes.0), vget_high_s8(q8bytes.0)),
+                );
+                let p1 = vaddq_s16(
+                    vmull_s8(vget_low_s8(q4bytes.1), vget_low_s8(q8bytes.1)),
+                    vmull_s8(vget_high_s8(q4bytes.1), vget_high_s8(q8bytes.1)),
+                );
+                sumi1 += vaddvq_s16(vaddq_s16(p0, p1)) as i32 * scales[2 * j] as i32;
+
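+                // Same again for the high nibbles, against the next 32 q8 bytes.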
+                let q8bytes = vld1q_s8_x2(q8);
+                q8 = q8.add(32);
+                let q4bytes = int8x16x2_t(
+                    vreinterpretq_s8_u8(vshrq_n_u8(q4bits.0, 4)),
+                    vreinterpretq_s8_u8(vshrq_n_u8(q4bits.1, 4)),
+                );
+                let p2 = vaddq_s16(
+                    vmull_s8(vget_low_s8(q4bytes.0), vget_low_s8(q8bytes.0)),
+                    vmull_s8(vget_high_s8(q4bytes.0), vget_high_s8(q8bytes.0)),
+                );
+                let p3 = vaddq_s16(
+                    vmull_s8(vget_low_s8(q4bytes.1), vget_low_s8(q8bytes.1)),
+                    vmull_s8(vget_high_s8(q4bytes.1), vget_high_s8(q8bytes.1)),
+                );
+                sumi2 += vaddvq_s16(vaddq_s16(p2, p3)) as i32 * scales[2 * j + 1] as i32;
+            }
+            sumf += d * (sumi1 + sumi2) as f32;
+        }
+    }
+    Ok(sumf)
+}