Diffstat (limited to 'candle-core/src')
-rw-r--r--  candle-core/src/quantized/k_quants.rs |   9
-rw-r--r--  candle-core/src/quantized/neon.rs     | 368
2 files changed, 369 insertions, 8 deletions
diff --git a/candle-core/src/quantized/k_quants.rs b/candle-core/src/quantized/k_quants.rs
index e7404529..65fd6a6e 100644
--- a/candle-core/src/quantized/k_quants.rs
+++ b/candle-core/src/quantized/k_quants.rs
@@ -676,6 +676,9 @@ impl GgmlType for BlockQ2K {
#[cfg(target_feature = "avx")]
return super::avx::vec_dot_q2k_q8k(n, xs, ys);
+ #[cfg(target_feature = "neon")]
+ return super::neon::vec_dot_q2k_q8k(n, xs, ys);
+
if n % QK_K != 0 {
crate::bail!("vec_dot_q2k_q8k: {n} is not divisible by {QK_K}")
}
@@ -843,6 +846,9 @@ impl GgmlType for BlockQ3K {
#[cfg(target_feature = "avx")]
return super::avx::vec_dot_q3k_q8k(n, xs, ys);
+ #[cfg(target_feature = "neon")]
+ return super::neon::vec_dot_q3k_q8k(n, xs, ys);
+
if n % QK_K != 0 {
crate::bail!("vec_dot_q3k_q8k: {n} is not divisible by {QK_K}")
}
@@ -1301,6 +1307,9 @@ impl GgmlType for BlockQ5K {
#[cfg(target_feature = "avx")]
return super::avx::vec_dot_q5k_q8k(n, xs, ys);
+ #[cfg(target_feature = "neon")]
+ return super::neon::vec_dot_q5k_q8k(n, xs, ys);
+
if n % QK_K != 0 {
crate::bail!("vec_dot_q5k_q8k: {n} is not divisible by {QK_K}")
}
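
All three hunks above use the same compile-time dispatch idiom as the existing AVX path: an early `return` guarded by `#[cfg(target_feature = ...)]`, leaving the portable scalar loop after it as the fallback. A minimal sketch of the pattern, with `neon_vec_dot` as a hypothetical stand-in for the real kernels:

    // Sketch only: `neon_vec_dot` is a placeholder. On a NEON target the
    // cfg'd return is compiled in and the scalar loop below becomes dead
    // code; on other targets the guarded line does not exist at all.
    #[allow(unreachable_code)]
    fn vec_dot(xs: &[i8], ys: &[i8]) -> i32 {
        #[cfg(target_feature = "neon")]
        return neon_vec_dot(xs, ys);

        xs.iter().zip(ys).map(|(&x, &y)| x as i32 * y as i32).sum()
    }

Because selection happens at compile time there is no runtime CPU detection: the fast path only exists on targets where the feature is enabled (NEON is baseline on aarch64).
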
diff --git a/candle-core/src/quantized/neon.rs b/candle-core/src/quantized/neon.rs
index 69d616f4..7f76dadc 100644
--- a/candle-core/src/quantized/neon.rs
+++ b/candle-core/src/quantized/neon.rs
@@ -1,4 +1,6 @@
-use super::k_quants::{BlockQ4K, BlockQ4_0, BlockQ6K, BlockQ8K, BlockQ8_0, QK8_0, QK_K};
+use super::k_quants::{
+ BlockQ2K, BlockQ3K, BlockQ4K, BlockQ4_0, BlockQ5K, BlockQ6K, BlockQ8K, BlockQ8_0, QK8_0, QK_K,
+};
use crate::Result;
use byteorder::{ByteOrder, LittleEndian};
@@ -282,6 +284,104 @@ pub(crate) fn vec_dot_q6k_q8k(n: usize, xs: &[BlockQ6K], ys: &[BlockQ8K]) -> Res
}
#[inline(always)]
+pub(crate) fn vec_dot_q5k_q8k(n: usize, xs: &[BlockQ5K], ys: &[BlockQ8K]) -> Result<f32> {
+ if n % QK_K != 0 {
+ crate::bail!("vec_dot_q5k_q8k: {n} is not divisible by {QK_K}")
+ }
+ let mut sumf = 0f32;
+ let mut utmp = [0u32; 4];
+ const KMASK1: u32 = 0x3f3f3f3f;
+ const KMASK2: u32 = 0x0f0f0f0f;
+ const KMASK3: u32 = 0x03030303;
+
+ unsafe {
+ let m4b = vdupq_n_u8(0xF);
+ let mone = vdupq_n_u8(1);
+ let mtwo = vdupq_n_u8(2);
+
+ for (x, y) in xs.iter().zip(ys.iter()) {
+ let d = y.d * x.d.to_f32();
+ let dmin = y.d * x.dmin.to_f32();
+
+ let q8sums = vpaddq_s16(
+ vld1q_s16(y.bsums.as_ptr()),
+ vld1q_s16(y.bsums.as_ptr().add(8)),
+ );
+
+ LittleEndian::read_u32_into(&x.scales, &mut utmp[0..3]);
+
+ utmp[3] = ((utmp[2] >> 4) & KMASK2) | (((utmp[1] >> 6) & KMASK3) << 4);
+ let uaux = utmp[1] & KMASK1;
+ utmp[1] = (utmp[2] & KMASK2) | (((utmp[0] >> 6) & KMASK3) << 4);
+ utmp[2] = uaux;
+ utmp[0] &= KMASK1;
+
+ let mins8 = vld1_u8((utmp.as_ptr() as *const u8).add(8));
+ let mins = vreinterpretq_s16_u16(vmovl_u8(mins8));
+ let prod = vaddq_s32(
+ vmull_s16(vget_low_s16(q8sums), vget_low_s16(mins)),
+ vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins)),
+ );
+ let sumi_mins = vaddvq_s32(prod);
+
+ let mut scales = utmp.as_ptr() as *const u8;
+
+ let mut q5 = x.qs.as_ptr();
+ let mut q8 = y.qs.as_ptr();
+
+ let mut qhbits = vld1q_u8_x2(x.qh.as_ptr());
+
+ let mut sumi = 0i32;
+
+ for _j in 0..QK_K / 64 {
+ let q5bits = vld1q_u8_x2(q5);
+ q5 = q5.add(32);
+ let q8bytes = vld1q_s8_x4(q8);
+ q8 = q8.add(64);
+
+ let q5h_0 = vshlq_n_u8(vandq_u8(mone, qhbits.0), 4);
+ let q5h_1 = vshlq_n_u8(vandq_u8(mone, qhbits.1), 4);
+ let q5h_2 = vshlq_n_u8(vandq_u8(mtwo, qhbits.0), 3);
+ let q5h_3 = vshlq_n_u8(vandq_u8(mtwo, qhbits.1), 3);
+ qhbits.0 = vshrq_n_u8(qhbits.0, 2);
+ qhbits.1 = vshrq_n_u8(qhbits.1, 2);
+
+ let q5bytes_0 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.0, m4b), q5h_0));
+ let q5bytes_1 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.1, m4b), q5h_1));
+ let q5bytes_2 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.0, 4), q5h_2));
+ let q5bytes_3 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.1, 4), q5h_3));
+
+ // TODO: dotprod
+
+ let p0 = vaddq_s16(
+ vmull_s8(vget_low_s8(q5bytes_0), vget_low_s8(q8bytes.0)),
+ vmull_s8(vget_high_s8(q5bytes_0), vget_high_s8(q8bytes.0)),
+ );
+ let p1 = vaddq_s16(
+ vmull_s8(vget_low_s8(q5bytes_1), vget_low_s8(q8bytes.1)),
+ vmull_s8(vget_high_s8(q5bytes_1), vget_high_s8(q8bytes.1)),
+ );
+ sumi += vaddvq_s16(vaddq_s16(p0, p1)) as i32 * *scales as i32;
+ scales = scales.add(1);
+
+ let p2 = vaddq_s16(
+ vmull_s8(vget_low_s8(q5bytes_2), vget_low_s8(q8bytes.2)),
+ vmull_s8(vget_high_s8(q5bytes_2), vget_high_s8(q8bytes.2)),
+ );
+ let p3 = vaddq_s16(
+ vmull_s8(vget_low_s8(q5bytes_3), vget_low_s8(q8bytes.3)),
+ vmull_s8(vget_high_s8(q5bytes_3), vget_high_s8(q8bytes.3)),
+ );
+ sumi += vaddvq_s16(vaddq_s16(p2, p3)) as i32 * *scales as i32;
+ scales = scales.add(1);
+ }
+ sumf += d * sumi as f32 - dmin * sumi_mins as f32;
+ }
+ }
+ Ok(sumf)
+}
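
For readers mapping the intrinsics back to the math, the closing `sumf += d * sumi as f32 - dmin * sumi_mins as f32` is the superblock identity shared by the k-quant kernels: a weight dequantizes as d * scale[i] * q - dmin * min[i], so the min part of the dot product collapses onto the precomputed q8 sub-block sums. A scalar sketch (slice names and shapes illustrative; the NEON code above first pairs the 16 `bsums` entries down to 8 with vpaddq_s16):

    // Scalar sketch of the per-superblock reduction.
    fn superblock_dot(
        d: f32,
        dmin: f32,
        scales: &[u8], // 6-bit scale per sub-block
        mins: &[u8],   // 6-bit min per sub-block
        dots: &[i32],  // raw q.q8 dot of each sub-block
        bsums: &[i32], // precomputed sum of the q8 values in each sub-block
    ) -> f32 {
        let sumi: i32 = scales.iter().zip(dots).map(|(&s, &q)| s as i32 * q).sum();
        let sumi_mins: i32 = mins.iter().zip(bsums).map(|(&m, &b)| m as i32 * b).sum();
        d * sumi as f32 - dmin * sumi_mins as f32
    }
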
+
+#[inline(always)]
pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Result<f32> {
if n % QK_K != 0 {
crate::bail!("vec_dot_q4k_q8k: {n} is not divisible by {QK_K}")
@@ -289,9 +389,9 @@ pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Res
let mut sumf = 0f32;
let mut utmp = [0u32; 4];
let mut scales = [0u8; 16];
- let kmask1: u32 = 0x3f3f3f3f;
- let kmask2: u32 = 0x0f0f0f0f;
- let kmask3: u32 = 0x03030303;
+ const KMASK1: u32 = 0x3f3f3f3f;
+ const KMASK2: u32 = 0x0f0f0f0f;
+ const KMASK3: u32 = 0x03030303;
unsafe {
let m4b = vdupq_n_u8(0xF);
@@ -309,13 +409,13 @@ pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Res
let mins8 = vld1_u32(
[
- utmp[1] & kmask1,
- ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4),
+ utmp[1] & KMASK1,
+ ((utmp[2] >> 4) & KMASK2) | (((utmp[1] >> 6) & KMASK3) << 4),
]
.as_ptr(),
);
- utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
- utmp[0] &= kmask1;
+ utmp[1] = (utmp[2] & KMASK2) | (((utmp[0] >> 6) & KMASK3) << 4);
+ utmp[0] &= KMASK1;
let mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins8)));
let prod = vaddq_s32(
@@ -373,3 +473,255 @@ pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Res
}
Ok(sumf)
}
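
The KMASK1/KMASK2/KMASK3 shuffles here and in vec_dot_q5k_q8k both unpack the 12-byte `scales` field, which packs eight 6-bit scales and eight 6-bit mins. After the utmp rewrite, bytes 0..8 of utmp hold the scales and bytes 8..16 hold the mins, which is what the subsequent 8-byte loads read. A scalar equivalent for a single pair, modeled on llama.cpp's get_scale_min_k4 (the Rust port here is illustrative):

    // Scalar sketch: the j-th (scale, min) pair of a k-quant superblock.
    // Pairs 0..=3 are stored as whole 6-bit values; pairs 4..=7 are rebuilt
    // from a nibble plus two high bits borrowed from the earlier bytes.
    fn get_scale_min_k4(j: usize, q: &[u8; 12]) -> (u8, u8) {
        if j < 4 {
            (q[j] & 63, q[j + 4] & 63)
        } else {
            let scale = (q[j + 4] & 0x0F) | ((q[j - 4] >> 6) << 4);
            let min = (q[j + 4] >> 4) | ((q[j] >> 6) << 4);
            (scale, min)
        }
    }
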
+
+#[inline(always)]
+pub(crate) fn vec_dot_q3k_q8k(n: usize, xs: &[BlockQ3K], ys: &[BlockQ8K]) -> Result<f32> {
+ if n % QK_K != 0 {
+ crate::bail!("vec_dot_q3k_q8k: {n} is not divisible by {QK_K}")
+ }
+ let mut sumf = 0f32;
+ let mut utmp = [0u32; 4];
+ let mut aux = [0u32; 3];
+ const KMASK1: u32 = 0x03030303;
+ const KMASK2: u32 = 0x0f0f0f0f;
+
+ unsafe {
+ let m3b = vdupq_n_u8(0x3);
+ let m0 = vdupq_n_u8(1);
+ let m1 = vshlq_n_u8(m0, 1);
+ let m2 = vshlq_n_u8(m0, 2);
+ let m3 = vshlq_n_u8(m0, 3);
+ for (x, y) in xs.iter().zip(ys.iter()) {
+ let d = y.d * x.d.to_f32();
+ let mut q3 = x.qs.as_ptr();
+ let qh = x.hmask.as_ptr();
+ let mut q8 = y.qs.as_ptr();
+
+ let mut qhbits = vld1q_u8_x2(qh);
+
+ let mut isum = 0i32;
+
+ // Set up scales
+ LittleEndian::read_u32_into(&x.scales, &mut aux);
+
+ utmp[3] = ((aux[1] >> 4) & KMASK2) | (((aux[2] >> 6) & KMASK1) << 4);
+ utmp[2] = ((aux[0] >> 4) & KMASK2) | (((aux[2] >> 4) & KMASK1) << 4);
+ utmp[1] = (aux[1] & KMASK2) | (((aux[2] >> 2) & KMASK1) << 4);
+ utmp[0] = (aux[0] & KMASK2) | ((aux[2] & KMASK1) << 4);
+
+ let mut scale = utmp.as_mut_ptr() as *mut i8;
+ for j in 0..16 {
+ *scale.add(j) -= 32i8
+ }
+
+ for j in 0..QK_K / 128 {
+ let q3bits = vld1q_u8_x2(q3);
+ q3 = q3.add(32);
+ let q8bytes_1 = vld1q_s8_x4(q8);
+ q8 = q8.add(64);
+ let q8bytes_2 = vld1q_s8_x4(q8);
+ q8 = q8.add(64);
+
+ let q3h_0 = vshlq_n_u8(vbicq_u8(m0, qhbits.0), 2);
+ let q3h_1 = vshlq_n_u8(vbicq_u8(m0, qhbits.1), 2);
+ let q3h_2 = vshlq_n_u8(vbicq_u8(m1, qhbits.0), 1);
+ let q3h_3 = vshlq_n_u8(vbicq_u8(m1, qhbits.1), 1);
+
+ let q3bytes_0 = vsubq_s8(
+ vreinterpretq_s8_u8(vandq_u8(q3bits.0, m3b)),
+ vreinterpretq_s8_u8(q3h_0),
+ );
+ let q3bytes_1 = vsubq_s8(
+ vreinterpretq_s8_u8(vandq_u8(q3bits.1, m3b)),
+ vreinterpretq_s8_u8(q3h_1),
+ );
+ let q3bytes_2 = vsubq_s8(
+ vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.0, 2), m3b)),
+ vreinterpretq_s8_u8(q3h_2),
+ );
+ let q3bytes_3 = vsubq_s8(
+ vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.1, 2), m3b)),
+ vreinterpretq_s8_u8(q3h_3),
+ );
+
+ // TODO: dotprod
+ let p0 = vaddq_s16(
+ vmull_s8(vget_low_s8(q3bytes_0), vget_low_s8(q8bytes_1.0)),
+ vmull_s8(vget_high_s8(q3bytes_0), vget_high_s8(q8bytes_1.0)),
+ );
+ let p1 = vaddq_s16(
+ vmull_s8(vget_low_s8(q3bytes_1), vget_low_s8(q8bytes_1.1)),
+ vmull_s8(vget_high_s8(q3bytes_1), vget_high_s8(q8bytes_1.1)),
+ );
+ let p2 = vaddq_s16(
+ vmull_s8(vget_low_s8(q3bytes_2), vget_low_s8(q8bytes_1.2)),
+ vmull_s8(vget_high_s8(q3bytes_2), vget_high_s8(q8bytes_1.2)),
+ );
+ let p3 = vaddq_s16(
+ vmull_s8(vget_low_s8(q3bytes_3), vget_low_s8(q8bytes_1.3)),
+ vmull_s8(vget_high_s8(q3bytes_3), vget_high_s8(q8bytes_1.3)),
+ );
+ isum += vaddvq_s16(p0) as i32 * *scale as i32
+ + vaddvq_s16(p1) as i32 * *scale.add(1) as i32
+ + vaddvq_s16(p2) as i32 * *scale.add(2) as i32
+ + vaddvq_s16(p3) as i32 * *scale.add(3) as i32;
+ scale = scale.add(4);
+
+ let q3h_0 = vbicq_u8(m2, qhbits.0);
+ let q3h_1 = vbicq_u8(m2, qhbits.1);
+ let q3h_2 = vshrq_n_u8(vbicq_u8(m3, qhbits.0), 1);
+ let q3h_3 = vshrq_n_u8(vbicq_u8(m3, qhbits.1), 1);
+
+ let q3bytes_0 = vsubq_s8(
+ vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.0, 4), m3b)),
+ vreinterpretq_s8_u8(q3h_0),
+ );
+ let q3bytes_1 = vsubq_s8(
+ vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.1, 4), m3b)),
+ vreinterpretq_s8_u8(q3h_1),
+ );
+ let q3bytes_2 = vsubq_s8(
+ vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.0, 6), m3b)),
+ vreinterpretq_s8_u8(q3h_2),
+ );
+ let q3bytes_3 = vsubq_s8(
+ vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.1, 6), m3b)),
+ vreinterpretq_s8_u8(q3h_3),
+ );
+
+ // TODO: dotprod
+ let p0 = vaddq_s16(
+ vmull_s8(vget_low_s8(q3bytes_0), vget_low_s8(q8bytes_2.0)),
+ vmull_s8(vget_high_s8(q3bytes_0), vget_high_s8(q8bytes_2.0)),
+ );
+ let p1 = vaddq_s16(
+ vmull_s8(vget_low_s8(q3bytes_1), vget_low_s8(q8bytes_2.1)),
+ vmull_s8(vget_high_s8(q3bytes_1), vget_high_s8(q8bytes_2.1)),
+ );
+ let p2 = vaddq_s16(
+ vmull_s8(vget_low_s8(q3bytes_2), vget_low_s8(q8bytes_2.2)),
+ vmull_s8(vget_high_s8(q3bytes_2), vget_high_s8(q8bytes_2.2)),
+ );
+ let p3 = vaddq_s16(
+ vmull_s8(vget_low_s8(q3bytes_3), vget_low_s8(q8bytes_2.3)),
+ vmull_s8(vget_high_s8(q3bytes_3), vget_high_s8(q8bytes_2.3)),
+ );
+ isum += vaddvq_s16(p0) as i32 * *scale as i32
+ + vaddvq_s16(p1) as i32 * *scale.add(1) as i32
+ + vaddvq_s16(p2) as i32 * *scale.add(2) as i32
+ + vaddvq_s16(p3) as i32 * *scale.add(3) as i32;
+ scale = scale.add(4);
+
+ if j == 0 {
+ qhbits.0 = vshrq_n_u8(qhbits.0, 4);
+ qhbits.1 = vshrq_n_u8(qhbits.1, 4);
+ }
+ }
+ sumf += d * isum as f32;
+ }
+ }
+ Ok(sumf)
+}
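
The vbicq_u8/vsubq_s8 sequence above is the Q3K decode: two low bits per weight come from `qs` and one high bit from `hmask`, with a clear high bit pulling the value down by 4 (vbicq_u8(m, qhbits) computes m & !qhbits, selecting exactly those weights). A scalar sketch of one weight (parameter names illustrative):

    // Scalar sketch of the 3-bit decode: results lie in -4..=3.
    fn q3_decode(qs_byte: u8, shift: u32, hmask_bit_set: bool) -> i8 {
        let low = ((qs_byte >> shift) & 3) as i8;
        if hmask_bit_set { low } else { low - 4 }
    }

Each 128-weight pass consumes four `hmask` bit-planes (m0..m3), so the `if j == 0` right-shift by 4 exposes the next four planes for the second pass.
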
+
+#[inline(always)]
+pub(crate) fn vec_dot_q2k_q8k(n: usize, xs: &[BlockQ2K], ys: &[BlockQ8K]) -> Result<f32> {
+ if n % QK_K != 0 {
+ crate::bail!("vec_dot_q2k_q8k: {n} is not divisible by {QK_K}")
+ }
+ let mut sumf = 0f32;
+ let mut aux = [0u8; 16];
+
+ unsafe {
+ let m3 = vdupq_n_u8(0x3);
+ let m4 = vdupq_n_u8(0xF);
+
+ for (x, y) in xs.iter().zip(ys.iter()) {
+ let d = y.d * x.d.to_f32();
+ let dmin = -y.d * x.dmin.to_f32();
+
+ let mut q2 = x.qs.as_ptr();
+ let mut q8 = y.qs.as_ptr();
+ let sc = x.scales.as_ptr();
+
+ let mins_and_scales = vld1q_u8(sc);
+ let scales = vandq_u8(mins_and_scales, m4);
+ vst1q_u8(aux.as_mut_ptr(), scales);
+
+ let mins = vshrq_n_u8(mins_and_scales, 4);
+ let q8sums = vld1q_s16_x2(y.bsums.as_ptr());
+ let mins16 = int16x8x2_t(
+ vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(mins))),
+ vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(mins))),
+ );
+ let s0 = vaddq_s32(
+ vmull_s16(vget_low_s16(mins16.0), vget_low_s16(q8sums.0)),
+ vmull_s16(vget_high_s16(mins16.0), vget_high_s16(q8sums.0)),
+ );
+ let s1 = vaddq_s32(
+ vmull_s16(vget_low_s16(mins16.1), vget_low_s16(q8sums.1)),
+ vmull_s16(vget_high_s16(mins16.1), vget_high_s16(q8sums.1)),
+ );
+ sumf += dmin * vaddvq_s32(vaddq_s32(s0, s1)) as f32;
+
+ let mut isum = 0i32;
+ let mut is = 0usize;
+
+ // TODO: dotprod
+
+ for _j in 0..QK_K / 128 {
+ let q2bits = vld1q_u8_x2(q2);
+ q2 = q2.add(32);
+
+ let q8bytes = vld1q_s8_x2(q8);
+ q8 = q8.add(32);
+ let mut q2bytes = int8x16x2_t(
+ vreinterpretq_s8_u8(vandq_u8(q2bits.0, m3)),
+ vreinterpretq_s8_u8(vandq_u8(q2bits.1, m3)),
+ );
+ isum += multiply_accum_with_scale(&aux, is, 0, q2bytes, q8bytes);
+
+ let q8bytes = vld1q_s8_x2(q8);
+ q8 = q8.add(32);
+ q2bytes.0 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.0, 2), m3));
+ q2bytes.1 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.1, 2), m3));
+ isum += multiply_accum_with_scale(&aux, is, 2, q2bytes, q8bytes);
+
+ let q8bytes = vld1q_s8_x2(q8);
+ q8 = q8.add(32);
+ q2bytes.0 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.0, 4), m3));
+ q2bytes.1 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.1, 4), m3));
+ isum += multiply_accum_with_scale(&aux, is, 4, q2bytes, q8bytes);
+
+ let q8bytes = vld1q_s8_x2(q8);
+ q8 = q8.add(32);
+ q2bytes.0 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.0, 6), m3));
+ q2bytes.1 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.1, 6), m3));
+ isum += multiply_accum_with_scale(&aux, is, 6, q2bytes, q8bytes);
+
+ is += 8;
+ }
+ sumf += d * isum as f32;
+ }
+ }
+ Ok(sumf)
+}
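
Two details in the loads above are easy to miss: each of the 16 Q2K `scales` bytes packs a 4-bit scale in the low nibble and a 4-bit min in the high nibble, so one vandq_u8/vshrq_n_u8 pair yields all sixteen of each; and `dmin` is negated up front so the mins term can be accumulated with an addition. A scalar sketch of the dequantization this implements for one weight (names illustrative):

    // Scalar sketch: `sc` is the sub-block's metadata byte, `q` a 2-bit value.
    fn q2_dequant(d: f32, dmin: f32, sc: u8, q: u8) -> f32 {
        let scale = (sc & 0x0F) as f32; // low nibble
        let min = (sc >> 4) as f32;     // high nibble
        d * scale * q as f32 - dmin * min
    }
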
+
+#[inline(always)]
+unsafe fn multiply_accum_with_scale(
+ aux: &[u8; 16],
+ is: usize,
+ index: usize,
+ q2bytes: int8x16x2_t,
+ q8bytes: int8x16x2_t,
+) -> i32 {
+ let p1 = vaddq_s16(
+ vmull_s8(vget_low_s8(q2bytes.0), vget_low_s8(q8bytes.0)),
+ vmull_s8(vget_high_s8(q2bytes.0), vget_high_s8(q8bytes.0)),
+ );
+ let p2 = vaddq_s16(
+ vmull_s8(vget_low_s8(q2bytes.1), vget_low_s8(q8bytes.1)),
+ vmull_s8(vget_high_s8(q2bytes.1), vget_high_s8(q8bytes.1)),
+ );
+ vaddvq_s16(p1) as i32 * aux[is + index] as i32
+ + vaddvq_s16(p2) as i32 * aux[is + 1 + index] as i32
+}
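
The `// TODO: dotprod` markers refer to the Armv8.2 SDOT extension: vdotq_s32 multiplies sixteen i8 pairs and accumulates into four i32 lanes, collapsing each vmull_s8/vaddq_s16/vaddvq_s16 chain into one instruction plus a horizontal add. A hedged sketch of what this helper could become on a dotprod-capable target (not part of this commit; assumes core::arch::aarch64::vdotq_s32 is available and the `dotprod` target feature is enabled):

    #[cfg(target_feature = "dotprod")]
    #[inline(always)]
    unsafe fn multiply_accum_with_scale_dotprod(
        aux: &[u8; 16],
        is: usize,
        index: usize,
        q2bytes: int8x16x2_t,
        q8bytes: int8x16x2_t,
    ) -> i32 {
        // Each vdotq_s32 lane holds the sum of four adjacent i8*i8 products,
        // so the horizontal vaddvq_s32 yields the same 16-product total as
        // the i16 widening path, with i32 accumulation throughout.
        let p1 = vdotq_s32(vdupq_n_s32(0), q2bytes.0, q8bytes.0);
        let p2 = vdotq_s32(vdupq_n_s32(0), q2bytes.1, q8bytes.1);
        vaddvq_s32(p1) * aux[is + index] as i32
            + vaddvq_s32(p2) * aux[is + 1 + index] as i32
    }
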