-rw-r--r--  candle-core/src/quantized/avx.rs      | 399
-rw-r--r--  candle-core/src/quantized/k_quants.rs |  12
2 files changed, 403 insertions(+), 8 deletions(-)
diff --git a/candle-core/src/quantized/avx.rs b/candle-core/src/quantized/avx.rs
index f906d090..e5fa058d 100644
--- a/candle-core/src/quantized/avx.rs
+++ b/candle-core/src/quantized/avx.rs
@@ -1,4 +1,6 @@
-use super::k_quants::{BlockQ4K, BlockQ4_0, BlockQ6K, BlockQ8K, BlockQ8_0, QK8_0, QK_K};
+use super::k_quants::{
+ BlockQ2K, BlockQ3K, BlockQ4K, BlockQ4_0, BlockQ5K, BlockQ6K, BlockQ8K, BlockQ8_0, QK8_0, QK_K,
+};
use crate::Result;
use byteorder::{ByteOrder, LittleEndian};
use half::f16;
@@ -120,6 +122,18 @@ unsafe fn get_scale_shuffle_k4(i: usize) -> __m256i {
}
#[inline(always)]
+unsafe fn get_scale_shuffle_q3k(i: usize) -> __m256i {
+ const K_SHUFFLE: [u8; 128] = [
+ 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3,
+ 2, 3, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7,
+ 6, 7, 6, 7, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 10, 11, 10, 11, 10, 11, 10, 11,
+ 10, 11, 10, 11, 10, 11, 10, 11, 12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12,
+ 13, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15,
+ ];
+ _mm256_loadu_si256((K_SHUFFLE.as_ptr() as *const __m256i).add(i))
+}
+
+#[inline(always)]
pub(crate) fn vec_dot_q6k_q8k(n: usize, xs: &[BlockQ6K], ys: &[BlockQ8K]) -> Result<f32> {
let qk = QK_K;
if n % qk != 0 {
@@ -212,15 +226,272 @@ unsafe fn mm256_set_m128i(a: __m128i, b: __m128i) -> __m256i {
_mm256_insertf128_si256(_mm256_castsi128_si256(b), a, 1)
}
+#[cfg_attr(not(debug_assertions), inline(always))]
+pub(crate) fn vec_dot_q2k_q8k(n: usize, xs: &[BlockQ2K], ys: &[BlockQ8K]) -> Result<f32> {
+ if n % QK_K != 0 {
+ crate::bail!("vec_dot_q2k_q8k: {n} is not divisible by {QK_K}")
+ }
+
+ unsafe {
+ let m3 = _mm256_set1_epi8(3);
+ let m4 = _mm_set1_epi8(0xF);
+
+ let mut acc = _mm256_setzero_ps();
+
+ for (x, y) in xs.iter().zip(ys.iter()) {
+ let d = y.d * x.d.to_f32();
+ let dmin = -y.d * x.dmin.to_f32();
+
+ let mut q2 = x.qs.as_ptr();
+ let mut q8 = y.qs.as_ptr();
+
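+ // x.scales holds 16 bytes, each packing a 4-bit sub-block scale (low nibble) and a 4-bit min (high nibble).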
+ let mins_and_scales = _mm_loadu_si128(x.scales.as_ptr() as *const __m128i);
+ let scales8 = _mm_and_si128(mins_and_scales, m4);
+ let mins8 = _mm_and_si128(_mm_srli_epi16(mins_and_scales, 4), m4);
+ let mins = _mm256_cvtepi8_epi16(mins8);
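+ // mins * bsums gives the constant offset term; it is folded into acc below with the negated dmin.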
+ let prod =
+ _mm256_madd_epi16(mins, _mm256_loadu_si256(y.bsums.as_ptr() as *const __m256i));
+
+ acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&dmin), _mm256_cvtepi32_ps(prod), acc);
+
+ let all_scales = _mm256_cvtepi8_epi16(scales8);
+ let l_scales = _mm256_extracti128_si256(all_scales, 0);
+ let h_scales = _mm256_extracti128_si256(all_scales, 1);
+ let scales = [
+ mm256_set_m128i(l_scales, l_scales),
+ mm256_set_m128i(h_scales, h_scales),
+ ];
+
+ let mut sumi = _mm256_setzero_si256();
+
+ for scale in scales {
+ let q2bits = _mm256_loadu_si256(q2 as *const __m256i);
+ q2 = q2.add(32);
+
+ let q8_0 = _mm256_loadu_si256(q8 as *const __m256i);
+ q8 = q8.add(32);
+ let q8_1 = _mm256_loadu_si256(q8 as *const __m256i);
+ q8 = q8.add(32);
+ let q8_2 = _mm256_loadu_si256(q8 as *const __m256i);
+ q8 = q8.add(32);
+ let q8_3 = _mm256_loadu_si256(q8 as *const __m256i);
+ q8 = q8.add(32);
+
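+ // Each byte of q2bits packs four 2-bit quants; peel them off plane by plane with the shifts below.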
+ let q2_0 = _mm256_and_si256(q2bits, m3);
+ let q2_1 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 2), m3);
+ let q2_2 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 4), m3);
+ let q2_3 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 6), m3);
+
+ let mut p0 = _mm256_maddubs_epi16(q2_0, q8_0);
+ let mut p1 = _mm256_maddubs_epi16(q2_1, q8_1);
+ let mut p2 = _mm256_maddubs_epi16(q2_2, q8_2);
+ let mut p3 = _mm256_maddubs_epi16(q2_3, q8_3);
+
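+ // Broadcast each 16-bit sub-block scale over its group of products before widening to 32 bits.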
+ p0 = _mm256_madd_epi16(_mm256_shuffle_epi8(scale, get_scale_shuffle_q3k(0)), p0);
+ p1 = _mm256_madd_epi16(_mm256_shuffle_epi8(scale, get_scale_shuffle_q3k(1)), p1);
+ p2 = _mm256_madd_epi16(_mm256_shuffle_epi8(scale, get_scale_shuffle_q3k(2)), p2);
+ p3 = _mm256_madd_epi16(_mm256_shuffle_epi8(scale, get_scale_shuffle_q3k(3)), p3);
+
+ p0 = _mm256_add_epi32(p0, p1);
+ p2 = _mm256_add_epi32(p2, p3);
+
+ sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p0, p2));
+ }
+ acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc);
+ }
+
+ Ok(hsum_float_8(acc))
+ }
+}
+
+#[cfg_attr(not(debug_assertions), inline(always))]
+pub(crate) fn vec_dot_q3k_q8k(n: usize, xs: &[BlockQ3K], ys: &[BlockQ8K]) -> Result<f32> {
+ if n % QK_K != 0 {
+ crate::bail!("vec_dot_q3k_q8k: {n} is not divisible by {QK_K}")
+ }
+
+ const KMASK1: u32 = 0x03030303;
+ const KMASK2: u32 = 0x0f0f0f0f;
+
+ let mut aux = [0u32; 3];
+
+ unsafe {
+ let m3 = _mm256_set1_epi8(3);
+ let mone = _mm256_set1_epi8(1);
+ let m32 = _mm_set1_epi8(32);
+
+ let mut acc = _mm256_setzero_ps();
+ for (x, y) in xs.iter().zip(ys.iter()) {
+ let d = y.d * x.d.to_f32();
+
+ let mut q3 = x.qs.as_ptr();
+ let mut q8 = y.qs.as_ptr();
+
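+ // x.scales packs 16 6-bit scales into 12 bytes; rebuild them here and re-center by subtracting 32.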
+ LittleEndian::read_u32_into(&x.scales, &mut aux);
+ let mut scales128 = _mm_set_epi32(
+ (((aux[1] >> 4) & KMASK2) | (((aux[2] >> 6) & KMASK1) << 4)) as i32,
+ (((aux[0] >> 4) & KMASK2) | (((aux[2] >> 4) & KMASK1) << 4)) as i32,
+ ((aux[1] & KMASK2) | (((aux[2] >> 2) & KMASK1) << 4)) as i32,
+ ((aux[0] & KMASK2) | (((aux[2]) & KMASK1) << 4)) as i32,
+ );
+ scales128 = _mm_sub_epi8(scales128, m32);
+ let all_scales = _mm256_cvtepi8_epi16(scales128);
+ let l_scales = _mm256_extracti128_si256(all_scales, 0);
+ let h_scales = _mm256_extracti128_si256(all_scales, 1);
+ let scales = [
+ mm256_set_m128i(l_scales, l_scales),
+ mm256_set_m128i(h_scales, h_scales),
+ ];
+
+ // high bit
+ let hbits = _mm256_loadu_si256(x.hmask.as_ptr() as *const __m256i);
+
+ // integer accumulator
+ let mut sumi = _mm256_setzero_si256();
+
+ for (j, scale) in scales.iter().enumerate() {
+ // load low 2 bits
+ let q3bits = _mm256_loadu_si256(q3 as *const __m256i);
+ q3 = q3.add(32);
+
+ // prepare low and high bits
+ // We hardcode the shifts here to avoid loading them into a separate register
+ let q3l_0 = _mm256_and_si256(q3bits, m3);
+ let q3h_0 = if j == 0 {
+ _mm256_slli_epi16(
+ _mm256_srli_epi16(
+ _mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 0)),
+ 0,
+ ),
+ 2,
+ )
+ } else {
+ _mm256_slli_epi16(
+ _mm256_srli_epi16(
+ _mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 4)),
+ 4,
+ ),
+ 2,
+ )
+ };
+
+ let q3l_1 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 2), m3);
+ let q3h_1 = if j == 0 {
+ _mm256_slli_epi16(
+ _mm256_srli_epi16(
+ _mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 1)),
+ 1,
+ ),
+ 2,
+ )
+ } else {
+ _mm256_slli_epi16(
+ _mm256_srli_epi16(
+ _mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 5)),
+ 5,
+ ),
+ 2,
+ )
+ };
+
+ let q3l_2 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 4), m3);
+ let q3h_2 = if j == 0 {
+ _mm256_slli_epi16(
+ _mm256_srli_epi16(
+ _mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 2)),
+ 2,
+ ),
+ 2,
+ )
+ } else {
+ _mm256_slli_epi16(
+ _mm256_srli_epi16(
+ _mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 6)),
+ 6,
+ ),
+ 2,
+ )
+ };
+
+ let q3l_3 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 6), m3);
+ let q3h_3 = if j == 0 {
+ _mm256_slli_epi16(
+ _mm256_srli_epi16(
+ _mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 3)),
+ 3,
+ ),
+ 2,
+ )
+ } else {
+ _mm256_slli_epi16(
+ _mm256_srli_epi16(
+ _mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 7)),
+ 7,
+ ),
+ 2,
+ )
+ };
+
+ // load Q8 quants
+ let q8_0 = _mm256_loadu_si256(q8 as *const __m256i);
+ q8 = q8.add(32);
+ let q8_1 = _mm256_loadu_si256(q8 as *const __m256i);
+ q8 = q8.add(32);
+ let q8_2 = _mm256_loadu_si256(q8 as *const __m256i);
+ q8 = q8.add(32);
+ let q8_3 = _mm256_loadu_si256(q8 as *const __m256i);
+ q8 = q8.add(32);
+
+ // Dot product: the 2 low bits and the high bit are multiplied with q8 separately so that
+ // _mm256_maddubs_epi16 can be used, and the high-bit product is then subtracted. q3h_x is 4
+ // wherever the high bit is clear and 0 where it is set, so the subtraction applies the -4
+ // offset of the Q3K encoding.
+ let q8s_0 = _mm256_maddubs_epi16(q3h_0, q8_0);
+ let q8s_1 = _mm256_maddubs_epi16(q3h_1, q8_1);
+ let q8s_2 = _mm256_maddubs_epi16(q3h_2, q8_2);
+ let q8s_3 = _mm256_maddubs_epi16(q3h_3, q8_3);
+
+ let mut p16_0 = _mm256_maddubs_epi16(q3l_0, q8_0);
+ let mut p16_1 = _mm256_maddubs_epi16(q3l_1, q8_1);
+ let mut p16_2 = _mm256_maddubs_epi16(q3l_2, q8_2);
+ let mut p16_3 = _mm256_maddubs_epi16(q3l_3, q8_3);
+
+ p16_0 = _mm256_sub_epi16(p16_0, q8s_0);
+ p16_1 = _mm256_sub_epi16(p16_1, q8s_1);
+ p16_2 = _mm256_sub_epi16(p16_2, q8s_2);
+ p16_3 = _mm256_sub_epi16(p16_3, q8s_3);
+
+ // multiply with scales
+ p16_0 =
+ _mm256_madd_epi16(_mm256_shuffle_epi8(*scale, get_scale_shuffle_q3k(0)), p16_0);
+ p16_1 =
+ _mm256_madd_epi16(_mm256_shuffle_epi8(*scale, get_scale_shuffle_q3k(1)), p16_1);
+ p16_2 =
+ _mm256_madd_epi16(_mm256_shuffle_epi8(*scale, get_scale_shuffle_q3k(2)), p16_2);
+ p16_3 =
+ _mm256_madd_epi16(_mm256_shuffle_epi8(*scale, get_scale_shuffle_q3k(3)), p16_3);
+
+ // accumulate
+ p16_0 = _mm256_add_epi32(p16_0, p16_1);
+ p16_2 = _mm256_add_epi32(p16_2, p16_3);
+ sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_2));
+ }
+
+ // multiply with block scale and accumulate
+ acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc);
+ }
+ Ok(hsum_float_8(acc))
+ }
+}
+
#[inline(always)]
pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Result<f32> {
if n % QK_K != 0 {
crate::bail!("vec_dot_q4k_q8k: {n} is not divisible by {QK_K}")
}
let mut utmp = [0u32; 4];
- let kmask1: u32 = 0x3f3f3f3f;
- let kmask2: u32 = 0x0f0f0f0f;
- let kmask3: u32 = 0x03030303;
+ const KMASK1: u32 = 0x3f3f3f3f;
+ const KMASK2: u32 = 0x0f0f0f0f;
+ const KMASK3: u32 = 0x03030303;
unsafe {
let m4 = _mm256_set1_epi8(0xF);
@@ -234,11 +505,11 @@ pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Res
LittleEndian::read_u32_into(&x.scales, &mut utmp[0..3]);
- utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
- let uaux = utmp[1] & kmask1;
- utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+ utmp[3] = ((utmp[2] >> 4) & KMASK2) | (((utmp[1] >> 6) & KMASK3) << 4);
+ let uaux = utmp[1] & KMASK1;
+ utmp[1] = (utmp[2] & KMASK2) | (((utmp[0] >> 6) & KMASK3) << 4);
utmp[2] = uaux;
- utmp[0] &= kmask1;
+ utmp[0] &= KMASK1;
let mut q4 = x.qs.as_ptr();
let mut q8 = y.qs.as_ptr();
@@ -295,3 +566,115 @@ pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Res
Ok(hsum_float_8(acc) + _mm_cvtss_f32(acc_m))
}
}
+
+#[cfg_attr(not(debug_assertions), inline(always))]
+pub(crate) fn vec_dot_q5k_q8k(n: usize, xs: &[BlockQ5K], ys: &[BlockQ8K]) -> Result<f32> {
+ if n % QK_K != 0 {
+ crate::bail!("vec_dot_q5k_q8k: {n} is not divisible by {QK_K}")
+ }
+ let mut utmp = [0u32; 4];
+ const KMASK1: u32 = 0x3f3f3f3f;
+ const KMASK2: u32 = 0x0f0f0f0f;
+ const KMASK3: u32 = 0x03030303;
+
+ unsafe {
+ let m4 = _mm256_set1_epi8(0xF);
+ let mzero = _mm_setzero_si128();
+ let mone = _mm256_set1_epi8(1);
+
+ let mut acc = _mm256_setzero_ps();
+ let mut summs = 0.0;
+
+ for (x, y) in xs.iter().zip(ys.iter()) {
+ let d = y.d * x.d.to_f32();
+ let dmin = -y.d * x.dmin.to_f32();
+
+ LittleEndian::read_u32_into(&x.scales, &mut utmp[0..3]);
+
+ utmp[3] = ((utmp[2] >> 4) & KMASK2) | (((utmp[1] >> 6) & KMASK3) << 4);
+ let uaux = utmp[1] & KMASK1;
+ utmp[1] = (utmp[2] & KMASK2) | (((utmp[0] >> 6) & KMASK3) << 4);
+ utmp[2] = uaux;
+ utmp[0] &= KMASK1;
+
+ let mut q5 = x.qs.as_ptr();
+ let mut q8 = y.qs.as_ptr();
+
+ let mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(
+ utmp[3] as i32,
+ utmp[2] as i32,
+ utmp[1] as i32,
+ utmp[0] as i32,
+ ));
+
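+ // Dot the 8 mins with the paired q8 block sums; the result feeds the dmin correction in `summs`.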
+ let q8sums = _mm256_loadu_si256(y.bsums.as_ptr() as *const __m256i);
+ let q8s = _mm_hadd_epi16(
+ _mm256_extracti128_si256(q8sums, 0),
+ _mm256_extracti128_si256(q8sums, 1),
+ );
+ let prod = _mm_madd_epi16(_mm256_extracti128_si256(mins_and_scales, 1), q8s);
+ let hsum = _mm_hadd_epi32(_mm_hadd_epi32(prod, mzero), mzero);
+ summs += dmin * _mm_extract_epi32(hsum, 0) as f32;
+
+ let sc128 = _mm256_extracti128_si256(mins_and_scales, 0);
+ let scales = mm256_set_m128i(sc128, sc128);
+
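+ // x.qh stores the fifth bit of every quant; hmask selects one bit plane at a time and is shifted after each half.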
+ let hbits = _mm256_loadu_si256(x.qh.as_ptr() as *const __m256i);
+ let mut hmask = mone;
+
+ let mut sumi = _mm256_setzero_si256();
+
+ for j in 0..QK_K / 64 {
+ let scale_0 = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2 * j));
+ let scale_1 = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2 * j + 1));
+
+ let q5bits = _mm256_loadu_si256(q5 as *const __m256i);
+ q5 = q5.add(32);
+
+ // As with q3k, we hardcode the shifts here to avoid loading them into a separate register
+ let q5l_0 = _mm256_and_si256(q5bits, m4);
+ let q5l_0_shift_input = _mm256_and_si256(hbits, hmask);
+ let q5l_0_right_shift = match j {
+ 0 => _mm256_srli_epi16(q5l_0_shift_input, 0),
+ 1 => _mm256_srli_epi16(q5l_0_shift_input, 2),
+ 2 => _mm256_srli_epi16(q5l_0_shift_input, 4),
+ 3 => _mm256_srli_epi16(q5l_0_shift_input, 6),
+ _ => unreachable!(),
+ };
+ let q5h_0 = _mm256_slli_epi16(q5l_0_right_shift, 4);
+ let q5_0 = _mm256_add_epi8(q5l_0, q5h_0);
+ hmask = _mm256_slli_epi16(hmask, 1);
+
+ let q5l_1 = _mm256_and_si256(_mm256_srli_epi16(q5bits, 4), m4);
+ let q5l_1_shift_input = _mm256_and_si256(hbits, hmask);
+ let q5l_1_right_shift = match j {
+ 0 => _mm256_srli_epi16(q5l_1_shift_input, 1),
+ 1 => _mm256_srli_epi16(q5l_1_shift_input, 3),
+ 2 => _mm256_srli_epi16(q5l_1_shift_input, 5),
+ 3 => _mm256_srli_epi16(q5l_1_shift_input, 7),
+ _ => unreachable!(),
+ };
+
+ let q5h_1 = _mm256_slli_epi16(q5l_1_right_shift, 4);
+ let q5_1 = _mm256_add_epi8(q5l_1, q5h_1);
+ hmask = _mm256_slli_epi16(hmask, 1);
+
+ let q8_0 = _mm256_loadu_si256(q8 as *const __m256i);
+ q8 = q8.add(32);
+ let q8_1 = _mm256_loadu_si256(q8 as *const __m256i);
+ q8 = q8.add(32);
+
+ let mut p16_0 = _mm256_maddubs_epi16(q5_0, q8_0);
+ let mut p16_1 = _mm256_maddubs_epi16(q5_1, q8_1);
+
+ p16_0 = _mm256_madd_epi16(scale_0, p16_0);
+ p16_1 = _mm256_madd_epi16(scale_1, p16_1);
+
+ sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1));
+ }
+ let vd = _mm256_set1_ps(d);
+ acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi), acc);
+ }
+ Ok(hsum_float_8(acc) + summs)
+ }
+}
diff --git a/candle-core/src/quantized/k_quants.rs b/candle-core/src/quantized/k_quants.rs
index 7f14600b..e7404529 100644
--- a/candle-core/src/quantized/k_quants.rs
+++ b/candle-core/src/quantized/k_quants.rs
@@ -671,7 +671,11 @@ impl GgmlType for BlockQ2K {
const BLCK_SIZE: usize = QK_K;
type VecDotType = BlockQ8K;
+ #[allow(unreachable_code)]
fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
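+ // Dispatch to the AVX kernel when the target enables it; the scalar implementation below remains the fallback.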
+ #[cfg(target_feature = "avx")]
+ return super::avx::vec_dot_q2k_q8k(n, xs, ys);
+
if n % QK_K != 0 {
crate::bail!("vec_dot_q2k_q8k: {n} is not divisible by {QK_K}")
}
@@ -834,7 +838,11 @@ impl GgmlType for BlockQ3K {
const BLCK_SIZE: usize = QK_K;
type VecDotType = BlockQ8K;
+ #[allow(unreachable_code)]
fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
+ #[cfg(target_feature = "avx")]
+ return super::avx::vec_dot_q3k_q8k(n, xs, ys);
+
if n % QK_K != 0 {
crate::bail!("vec_dot_q3k_q8k: {n} is not divisible by {QK_K}")
}
@@ -1288,7 +1296,11 @@ impl GgmlType for BlockQ5K {
const BLCK_SIZE: usize = QK_K;
type VecDotType = BlockQ8K;
+ #[allow(unreachable_code)]
fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
+ #[cfg(target_feature = "avx")]
+ return super::avx::vec_dot_q5k_q8k(n, xs, ys);
+
if n % QK_K != 0 {
crate::bail!("vec_dot_q5k_q8k: {n} is not divisible by {QK_K}")
}