diff options
Diffstat (limited to 'candle-core/src')
-rw-r--r-- | candle-core/src/quantized/k_quants.rs | 9 | ||||
-rw-r--r-- | candle-core/src/quantized/neon.rs | 368 |
2 files changed, 369 insertions, 8 deletions
diff --git a/candle-core/src/quantized/k_quants.rs b/candle-core/src/quantized/k_quants.rs index e7404529..65fd6a6e 100644 --- a/candle-core/src/quantized/k_quants.rs +++ b/candle-core/src/quantized/k_quants.rs @@ -676,6 +676,9 @@ impl GgmlType for BlockQ2K { #[cfg(target_feature = "avx")] return super::avx::vec_dot_q2k_q8k(n, xs, ys); + #[cfg(target_feature = "neon")] + return super::neon::vec_dot_q2k_q8k(n, xs, ys); + if n % QK_K != 0 { crate::bail!("vec_dot_q2k_q8k: {n} is not divisible by {QK_K}") } @@ -843,6 +846,9 @@ impl GgmlType for BlockQ3K { #[cfg(target_feature = "avx")] return super::avx::vec_dot_q3k_q8k(n, xs, ys); + #[cfg(target_feature = "neon")] + return super::neon::vec_dot_q3k_q8k(n, xs, ys); + if n % QK_K != 0 { crate::bail!("vec_dot_q3k_q8k: {n} is not divisible by {QK_K}") } @@ -1301,6 +1307,9 @@ impl GgmlType for BlockQ5K { #[cfg(target_feature = "avx")] return super::avx::vec_dot_q5k_q8k(n, xs, ys); + #[cfg(target_feature = "neon")] + return super::neon::vec_dot_q5k_q8k(n, xs, ys); + if n % QK_K != 0 { crate::bail!("vec_dot_q5k_q8k: {n} is not divisible by {QK_K}") } diff --git a/candle-core/src/quantized/neon.rs b/candle-core/src/quantized/neon.rs index 69d616f4..7f76dadc 100644 --- a/candle-core/src/quantized/neon.rs +++ b/candle-core/src/quantized/neon.rs @@ -1,4 +1,6 @@ -use super::k_quants::{BlockQ4K, BlockQ4_0, BlockQ6K, BlockQ8K, BlockQ8_0, QK8_0, QK_K}; +use super::k_quants::{ + BlockQ2K, BlockQ3K, BlockQ4K, BlockQ4_0, BlockQ5K, BlockQ6K, BlockQ8K, BlockQ8_0, QK8_0, QK_K, +}; use crate::Result; use byteorder::{ByteOrder, LittleEndian}; @@ -282,6 +284,104 @@ pub(crate) fn vec_dot_q6k_q8k(n: usize, xs: &[BlockQ6K], ys: &[BlockQ8K]) -> Res } #[inline(always)] +pub(crate) fn vec_dot_q5k_q8k(n: usize, xs: &[BlockQ5K], ys: &[BlockQ8K]) -> Result<f32> { + if n % QK_K != 0 { + crate::bail!("vec_dot_q5k_q8k: {n} is not divisible by {QK_K}") + } + let mut sumf = 0f32; + let mut utmp = [0u32; 4]; + const KMASK1: u32 = 0x3f3f3f3f; + const KMASK2: u32 = 0x0f0f0f0f; + const KMASK3: u32 = 0x03030303; + + unsafe { + let m4b = vdupq_n_u8(0xF); + let mone = vdupq_n_u8(1); + let mtwo = vdupq_n_u8(2); + + for (x, y) in xs.iter().zip(ys.iter()) { + let d = y.d * x.d.to_f32(); + let dmin = y.d * x.dmin.to_f32(); + + let q8sums = vpaddq_s16( + vld1q_s16(y.bsums.as_ptr()), + vld1q_s16(y.bsums.as_ptr().add(8)), + ); + + LittleEndian::read_u32_into(&x.scales, &mut utmp[0..3]); + + utmp[3] = ((utmp[2] >> 4) & KMASK2) | (((utmp[1] >> 6) & KMASK3) << 4); + let uaux = utmp[1] & KMASK1; + utmp[1] = (utmp[2] & KMASK2) | (((utmp[0] >> 6) & KMASK3) << 4); + utmp[2] = uaux; + utmp[0] &= KMASK1; + + let mins8 = vld1_u8((utmp.as_ptr() as *const u8).add(8)); + let mins = vreinterpretq_s16_u16(vmovl_u8(mins8)); + let prod = vaddq_s32( + vmull_s16(vget_low_s16(q8sums), vget_low_s16(mins)), + vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins)), + ); + let sumi_mins = vaddvq_s32(prod); + + let mut scales = utmp.as_ptr() as *const u8; + + let mut q5 = x.qs.as_ptr(); + let mut q8 = y.qs.as_ptr(); + + let mut qhbits = vld1q_u8_x2(x.qh.as_ptr()); + + let mut sumi = 0i32; + + for _j in 0..QK_K / 64 { + let q5bits = vld1q_u8_x2(q5); + q5 = q5.add(32); + let q8bytes = vld1q_s8_x4(q8); + q8 = q8.add(64); + + let q5h_0 = vshlq_n_u8(vandq_u8(mone, qhbits.0), 4); + let q5h_1 = vshlq_n_u8(vandq_u8(mone, qhbits.1), 4); + let q5h_2 = vshlq_n_u8(vandq_u8(mtwo, qhbits.0), 3); + let q5h_3 = vshlq_n_u8(vandq_u8(mtwo, qhbits.1), 3); + qhbits.0 = vshrq_n_u8(qhbits.0, 2); + qhbits.1 = vshrq_n_u8(qhbits.1, 2); + + let q5bytes_0 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.0, m4b), q5h_0)); + let q5bytes_1 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.1, m4b), q5h_1)); + let q5bytes_2 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.0, 4), q5h_2)); + let q5bytes_3 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.1, 4), q5h_3)); + + // TODO: dotprod + + let p0 = vaddq_s16( + vmull_s8(vget_low_s8(q5bytes_0), vget_low_s8(q8bytes.0)), + vmull_s8(vget_high_s8(q5bytes_0), vget_high_s8(q8bytes.0)), + ); + let p1 = vaddq_s16( + vmull_s8(vget_low_s8(q5bytes_1), vget_low_s8(q8bytes.1)), + vmull_s8(vget_high_s8(q5bytes_1), vget_high_s8(q8bytes.1)), + ); + sumi += vaddvq_s16(vaddq_s16(p0, p1)) as i32 * *scales as i32; + scales = scales.add(1); + + let p2 = vaddq_s16( + vmull_s8(vget_low_s8(q5bytes_2), vget_low_s8(q8bytes.2)), + vmull_s8(vget_high_s8(q5bytes_2), vget_high_s8(q8bytes.2)), + ); + let p3 = vaddq_s16( + vmull_s8(vget_low_s8(q5bytes_3), vget_low_s8(q8bytes.3)), + vmull_s8(vget_high_s8(q5bytes_3), vget_high_s8(q8bytes.3)), + ); + sumi += vaddvq_s16(vaddq_s16(p2, p3)) as i32 * *scales as i32; + scales = scales.add(1); + } + sumf += d * sumi as f32 - dmin * sumi_mins as f32; + } + } + Ok(sumf) +} + +#[inline(always)] pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Result<f32> { if n % QK_K != 0 { crate::bail!("vec_dot_q4k_q8k: {n} is not divisible by {QK_K}") @@ -289,9 +389,9 @@ pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Res let mut sumf = 0f32; let mut utmp = [0u32; 4]; let mut scales = [0u8; 16]; - let kmask1: u32 = 0x3f3f3f3f; - let kmask2: u32 = 0x0f0f0f0f; - let kmask3: u32 = 0x03030303; + const KMASK1: u32 = 0x3f3f3f3f; + const KMASK2: u32 = 0x0f0f0f0f; + const KMASK3: u32 = 0x03030303; unsafe { let m4b = vdupq_n_u8(0xF); @@ -309,13 +409,13 @@ pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Res let mins8 = vld1_u32( [ - utmp[1] & kmask1, - ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4), + utmp[1] & KMASK1, + ((utmp[2] >> 4) & KMASK2) | (((utmp[1] >> 6) & KMASK3) << 4), ] .as_ptr(), ); - utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); - utmp[0] &= kmask1; + utmp[1] = (utmp[2] & KMASK2) | (((utmp[0] >> 6) & KMASK3) << 4); + utmp[0] &= KMASK1; let mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins8))); let prod = vaddq_s32( @@ -373,3 +473,255 @@ pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Res } Ok(sumf) } + +#[inline(always)] +pub(crate) fn vec_dot_q3k_q8k(n: usize, xs: &[BlockQ3K], ys: &[BlockQ8K]) -> Result<f32> { + if n % QK_K != 0 { + crate::bail!("vec_dot_q3k_q8k: {n} is not divisible by {QK_K}") + } + let mut sumf = 0f32; + let mut utmp = [0u32; 4]; + let mut aux = [0u32; 3]; + const KMASK1: u32 = 0x03030303; + const KMASK2: u32 = 0x0f0f0f0f; + + unsafe { + let m3b = vdupq_n_u8(0x3); + let m0 = vdupq_n_u8(1); + let m1 = vshlq_n_u8(m0, 1); + let m2 = vshlq_n_u8(m0, 2); + let m3 = vshlq_n_u8(m0, 3); + for (x, y) in xs.iter().zip(ys.iter()) { + let d = y.d * x.d.to_f32(); + let mut q3 = x.qs.as_ptr(); + let qh = x.hmask.as_ptr(); + let mut q8 = y.qs.as_ptr(); + + let mut qhbits = vld1q_u8_x2(qh); + + let mut isum = 0i32; + + // Set up scales + LittleEndian::read_u32_into(&x.scales, &mut aux); + + utmp[3] = ((aux[1] >> 4) & KMASK2) | (((aux[2] >> 6) & KMASK1) << 4); + utmp[2] = ((aux[0] >> 4) & KMASK2) | (((aux[2] >> 4) & KMASK1) << 4); + utmp[1] = (aux[1] & KMASK2) | (((aux[2] >> 2) & KMASK1) << 4); + utmp[0] = (aux[0] & KMASK2) | ((aux[2] & KMASK1) << 4); + + let mut scale = utmp.as_mut_ptr() as *mut i8; + for j in 0..16 { + *scale.add(j) -= 32i8 + } + + for j in 0..QK_K / 128 { + let q3bits = vld1q_u8_x2(q3); + q3 = q3.add(32); + let q8bytes_1 = vld1q_s8_x4(q8); + q8 = q8.add(64); + let q8bytes_2 = vld1q_s8_x4(q8); + q8 = q8.add(64); + + let q3h_0 = vshlq_n_u8(vbicq_u8(m0, qhbits.0), 2); + let q3h_1 = vshlq_n_u8(vbicq_u8(m0, qhbits.1), 2); + let q3h_2 = vshlq_n_u8(vbicq_u8(m1, qhbits.0), 1); + let q3h_3 = vshlq_n_u8(vbicq_u8(m1, qhbits.1), 1); + + let q3bytes_0 = vsubq_s8( + vreinterpretq_s8_u8(vandq_u8(q3bits.0, m3b)), + vreinterpretq_s8_u8(q3h_0), + ); + let q3bytes_1 = vsubq_s8( + vreinterpretq_s8_u8(vandq_u8(q3bits.1, m3b)), + vreinterpretq_s8_u8(q3h_1), + ); + let q3bytes_2 = vsubq_s8( + vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.0, 2), m3b)), + vreinterpretq_s8_u8(q3h_2), + ); + let q3bytes_3 = vsubq_s8( + vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.1, 2), m3b)), + vreinterpretq_s8_u8(q3h_3), + ); + + // TODO: dotprod + let p0 = vaddq_s16( + vmull_s8(vget_low_s8(q3bytes_0), vget_low_s8(q8bytes_1.0)), + vmull_s8(vget_high_s8(q3bytes_0), vget_high_s8(q8bytes_1.0)), + ); + let p1 = vaddq_s16( + vmull_s8(vget_low_s8(q3bytes_1), vget_low_s8(q8bytes_1.1)), + vmull_s8(vget_high_s8(q3bytes_1), vget_high_s8(q8bytes_1.1)), + ); + let p2 = vaddq_s16( + vmull_s8(vget_low_s8(q3bytes_2), vget_low_s8(q8bytes_1.2)), + vmull_s8(vget_high_s8(q3bytes_2), vget_high_s8(q8bytes_1.2)), + ); + let p3 = vaddq_s16( + vmull_s8(vget_low_s8(q3bytes_3), vget_low_s8(q8bytes_1.3)), + vmull_s8(vget_high_s8(q3bytes_3), vget_high_s8(q8bytes_1.3)), + ); + isum += vaddvq_s16(p0) as i32 * *scale as i32 + + vaddvq_s16(p1) as i32 * *scale.add(1) as i32 + + vaddvq_s16(p2) as i32 * *scale.add(2) as i32 + + vaddvq_s16(p3) as i32 * *scale.add(3) as i32; + scale = scale.add(4); + + let q3h_0 = vbicq_u8(m2, qhbits.0); + let q3h_1 = vbicq_u8(m2, qhbits.1); + let q3h_2 = vshrq_n_u8(vbicq_u8(m3, qhbits.0), 1); + let q3h_3 = vshrq_n_u8(vbicq_u8(m3, qhbits.1), 1); + + let q3bytes_0 = vsubq_s8( + vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.0, 4), m3b)), + vreinterpretq_s8_u8(q3h_0), + ); + let q3bytes_1 = vsubq_s8( + vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.1, 4), m3b)), + vreinterpretq_s8_u8(q3h_1), + ); + let q3bytes_2 = vsubq_s8( + vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.0, 6), m3b)), + vreinterpretq_s8_u8(q3h_2), + ); + let q3bytes_3 = vsubq_s8( + vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.1, 6), m3b)), + vreinterpretq_s8_u8(q3h_3), + ); + + // TODO: dotprod + let p0 = vaddq_s16( + vmull_s8(vget_low_s8(q3bytes_0), vget_low_s8(q8bytes_2.0)), + vmull_s8(vget_high_s8(q3bytes_0), vget_high_s8(q8bytes_2.0)), + ); + let p1 = vaddq_s16( + vmull_s8(vget_low_s8(q3bytes_1), vget_low_s8(q8bytes_2.1)), + vmull_s8(vget_high_s8(q3bytes_1), vget_high_s8(q8bytes_2.1)), + ); + let p2 = vaddq_s16( + vmull_s8(vget_low_s8(q3bytes_2), vget_low_s8(q8bytes_2.2)), + vmull_s8(vget_high_s8(q3bytes_2), vget_high_s8(q8bytes_2.2)), + ); + let p3 = vaddq_s16( + vmull_s8(vget_low_s8(q3bytes_3), vget_low_s8(q8bytes_2.3)), + vmull_s8(vget_high_s8(q3bytes_3), vget_high_s8(q8bytes_2.3)), + ); + isum += vaddvq_s16(p0) as i32 * *scale as i32 + + vaddvq_s16(p1) as i32 * *scale.add(1) as i32 + + vaddvq_s16(p2) as i32 * *scale.add(2) as i32 + + vaddvq_s16(p3) as i32 * *scale.add(3) as i32; + scale = scale.add(4); + + if j == 0 { + qhbits.0 = vshrq_n_u8(qhbits.0, 4); + qhbits.1 = vshrq_n_u8(qhbits.1, 4); + } + } + sumf += d * isum as f32; + } + } + Ok(sumf) +} + +#[inline(always)] +pub(crate) fn vec_dot_q2k_q8k(n: usize, xs: &[BlockQ2K], ys: &[BlockQ8K]) -> Result<f32> { + if n % QK_K != 0 { + crate::bail!("vec_dot_q2k_q8k: {n} is not divisible by {QK_K}") + } + let mut sumf = 0f32; + let mut aux = [0u8; 16]; + + unsafe { + let m3 = vdupq_n_u8(0x3); + let m4 = vdupq_n_u8(0xF); + + for (x, y) in xs.iter().zip(ys.iter()) { + let d = y.d * x.d.to_f32(); + let dmin = -y.d * x.dmin.to_f32(); + + let mut q2 = x.qs.as_ptr(); + let mut q8 = y.qs.as_ptr(); + let sc = x.scales.as_ptr(); + + let mins_and_scales = vld1q_u8(sc); + let scales = vandq_u8(mins_and_scales, m4); + vst1q_u8(aux.as_mut_ptr(), scales); + + let mins = vshrq_n_u8(mins_and_scales, 4); + let q8sums = vld1q_s16_x2(y.bsums.as_ptr()); + let mins16 = int16x8x2_t( + vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(mins))), + vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(mins))), + ); + let s0 = vaddq_s32( + vmull_s16(vget_low_s16(mins16.0), vget_low_s16(q8sums.0)), + vmull_s16(vget_high_s16(mins16.0), vget_high_s16(q8sums.0)), + ); + let s1 = vaddq_s32( + vmull_s16(vget_low_s16(mins16.1), vget_low_s16(q8sums.1)), + vmull_s16(vget_high_s16(mins16.1), vget_high_s16(q8sums.1)), + ); + sumf += dmin * vaddvq_s32(vaddq_s32(s0, s1)) as f32; + + let mut isum = 0i32; + let mut is = 0usize; + + // TODO: dotprod + + for _j in 0..QK_K / 128 { + let q2bits = vld1q_u8_x2(q2); + q2 = q2.add(32); + + let q8bytes = vld1q_s8_x2(q8); + q8 = q8.add(32); + let mut q2bytes = int8x16x2_t( + vreinterpretq_s8_u8(vandq_u8(q2bits.0, m3)), + vreinterpretq_s8_u8(vandq_u8(q2bits.1, m3)), + ); + isum += multiply_accum_with_scale(&aux, is, 0, q2bytes, q8bytes); + + let q8bytes = vld1q_s8_x2(q8); + q8 = q8.add(32); + q2bytes.0 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.0, 2), m3)); + q2bytes.1 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.1, 2), m3)); + isum += multiply_accum_with_scale(&aux, is, 2, q2bytes, q8bytes); + + let q8bytes = vld1q_s8_x2(q8); + q8 = q8.add(32); + q2bytes.0 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.0, 4), m3)); + q2bytes.1 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.1, 4), m3)); + isum += multiply_accum_with_scale(&aux, is, 4, q2bytes, q8bytes); + + let q8bytes = vld1q_s8_x2(q8); + q8 = q8.add(32); + q2bytes.0 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.0, 6), m3)); + q2bytes.1 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.1, 6), m3)); + isum += multiply_accum_with_scale(&aux, is, 6, q2bytes, q8bytes); + + is += 8; + } + sumf += d * isum as f32; + } + } + Ok(sumf) +} + +#[inline(always)] +unsafe fn multiply_accum_with_scale( + aux: &[u8; 16], + is: usize, + index: usize, + q2bytes: int8x16x2_t, + q8bytes: int8x16x2_t, +) -> i32 { + let p1 = vaddq_s16( + vmull_s8(vget_low_s8(q2bytes.0), vget_low_s8(q8bytes.0)), + vmull_s8(vget_high_s8(q2bytes.0), vget_high_s8(q8bytes.0)), + ); + let p2 = vaddq_s16( + vmull_s8(vget_low_s8(q2bytes.1), vget_low_s8(q8bytes.1)), + vmull_s8(vget_high_s8(q2bytes.1), vget_high_s8(q8bytes.1)), + ); + vaddvq_s16(p1) as i32 * aux[is + index] as i32 + + vaddvq_s16(p2) as i32 * aux[is + 1 + index] as i32 +} |