summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-core/src/quantized/avx.rs126
-rw-r--r--candle-core/src/quantized/k_quants.rs3
2 files changed, 120 insertions, 9 deletions
diff --git a/candle-core/src/quantized/avx.rs b/candle-core/src/quantized/avx.rs
index 96087feb..f906d090 100644
--- a/candle-core/src/quantized/avx.rs
+++ b/candle-core/src/quantized/avx.rs
@@ -1,5 +1,6 @@
-use super::k_quants::{BlockQ4_0, BlockQ6K, BlockQ8K, BlockQ8_0, QK8_0, QK_K};
+use super::k_quants::{BlockQ4K, BlockQ4_0, BlockQ6K, BlockQ8K, BlockQ8_0, QK8_0, QK_K};
use crate::Result;
+use byteorder::{ByteOrder, LittleEndian};
use half::f16;
#[cfg(target_arch = "x86")]
@@ -89,17 +90,35 @@ pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) ->
}
}
-const K_SHUFFLE: [u8; 128] = [
- 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3,
- 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7,
- 8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9, 10, 10, 10, 10, 10, 10, 10, 10, 11, 11, 11, 11,
- 11, 11, 11, 11, 12, 12, 12, 12, 12, 12, 12, 12, 13, 13, 13, 13, 13, 13, 13, 13, 14, 14, 14, 14,
- 14, 14, 14, 14, 15, 15, 15, 15, 15, 15, 15, 15,
-];
-
+#[inline(always)]
unsafe fn get_scale_shuffle(i: usize) -> __m128i {
+ const K_SHUFFLE: [u8; 128] = [
+ 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3,
+ 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7,
+ 7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9, 10, 10, 10, 10, 10, 10, 10, 10,
+ 11, 11, 11, 11, 11, 11, 11, 11, 12, 12, 12, 12, 12, 12, 12, 12, 13, 13, 13, 13, 13, 13, 13,
+ 13, 14, 14, 14, 14, 14, 14, 14, 14, 15, 15, 15, 15, 15, 15, 15, 15,
+ ];
_mm_loadu_si128((K_SHUFFLE.as_ptr() as *const __m128i).add(i))
}
+
+#[inline(always)]
+unsafe fn get_scale_shuffle_k4(i: usize) -> __m256i {
+ const K_SHUFFLE: [u8; 256] = [
+ 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1,
+ 0, 1, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3,
+ 2, 3, 2, 3, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5,
+ 4, 5, 4, 5, 4, 5, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7,
+ 6, 7, 6, 7, 6, 7, 6, 7, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9,
+ 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 10, 11, 10, 11, 10, 11, 10, 11, 10, 11, 10, 11, 10, 11, 10,
+ 11, 10, 11, 10, 11, 10, 11, 10, 11, 10, 11, 10, 11, 10, 11, 10, 11, 12, 13, 12, 13, 12, 13,
+ 12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12,
+ 13, 12, 13, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15,
+ 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15,
+ ];
+ _mm256_loadu_si256((K_SHUFFLE.as_ptr() as *const __m256i).add(i))
+}
+
#[inline(always)]
pub(crate) fn vec_dot_q6k_q8k(n: usize, xs: &[BlockQ6K], ys: &[BlockQ8K]) -> Result<f32> {
let qk = QK_K;
@@ -187,3 +206,92 @@ pub(crate) fn vec_dot_q6k_q8k(n: usize, xs: &[BlockQ6K], ys: &[BlockQ8K]) -> Res
Ok(hsum_float_8(acc))
}
}
+
+#[inline(always)]
+unsafe fn mm256_set_m128i(a: __m128i, b: __m128i) -> __m256i {
+ _mm256_insertf128_si256(_mm256_castsi128_si256(b), a, 1)
+}
+
+#[inline(always)]
+pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Result<f32> {
+ if n % QK_K != 0 {
+ crate::bail!("vec_dot_q4k_q8k: {n} is not divisible by {QK_K}")
+ }
+ let mut utmp = [0u32; 4];
+ let kmask1: u32 = 0x3f3f3f3f;
+ let kmask2: u32 = 0x0f0f0f0f;
+ let kmask3: u32 = 0x03030303;
+
+ unsafe {
+ let m4 = _mm256_set1_epi8(0xF);
+
+ let mut acc = _mm256_setzero_ps();
+ let mut acc_m = _mm_setzero_ps();
+
+ for (x, y) in xs.iter().zip(ys.iter()) {
+ let d = y.d * x.d.to_f32();
+ let dmin = -y.d * x.dmin.to_f32();
+
+ LittleEndian::read_u32_into(&x.scales, &mut utmp[0..3]);
+
+ utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
+ let uaux = utmp[1] & kmask1;
+ utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+ utmp[2] = uaux;
+ utmp[0] &= kmask1;
+
+ let mut q4 = x.qs.as_ptr();
+ let mut q8 = y.qs.as_ptr();
+
+ let mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(
+ utmp[3] as i32,
+ utmp[2] as i32,
+ utmp[1] as i32,
+ utmp[0] as i32,
+ ));
+
+ let q8sums = _mm256_loadu_si256(y.bsums.as_ptr() as *const __m256i);
+ let q8s = _mm_hadd_epi16(
+ _mm256_extracti128_si256(q8sums, 0),
+ _mm256_extracti128_si256(q8sums, 1),
+ );
+ let prod = _mm_madd_epi16(_mm256_extracti128_si256(mins_and_scales, 1), q8s);
+ acc_m = _mm_fmadd_ps(_mm_set1_ps(dmin), _mm_cvtepi32_ps(prod), acc_m);
+
+ let sc128 = _mm256_extracti128_si256(mins_and_scales, 0);
+ let scales = mm256_set_m128i(sc128, sc128);
+
+ let mut sumi = _mm256_setzero_si256();
+
+ for j in 0..QK_K / 64 {
+ let scale_l = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2 * j));
+ let scale_h = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2 * j + 1));
+
+ let q4bits = _mm256_loadu_si256(q4 as *const __m256i);
+ q4 = q4.add(32);
+ let q4l = _mm256_and_si256(q4bits, m4);
+ let q4h = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), m4);
+
+ let q8l = _mm256_loadu_si256(q8 as *const __m256i);
+ q8 = q8.add(32);
+ let p16l = _mm256_maddubs_epi16(q4l, q8l);
+ let p16l = _mm256_madd_epi16(scale_l, p16l);
+ sumi = _mm256_add_epi32(sumi, p16l);
+
+ let q8h = _mm256_loadu_si256(q8 as *const __m256i);
+ q8 = q8.add(32);
+ let p16h = _mm256_maddubs_epi16(q4h, q8h);
+ let p16h = _mm256_madd_epi16(scale_h, p16h);
+ sumi = _mm256_add_epi32(sumi, p16h);
+ }
+
+ let vd = _mm256_set1_ps(d);
+ acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi), acc);
+ }
+
+ let acc_m = _mm_add_ps(acc_m, _mm_movehl_ps(acc_m, acc_m));
+ let acc_m = _mm_add_ss(acc_m, _mm_movehdup_ps(acc_m));
+
+ Ok(hsum_float_8(acc) + _mm_cvtss_f32(acc_m))
+ }
+}
diff --git a/candle-core/src/quantized/k_quants.rs b/candle-core/src/quantized/k_quants.rs
index 7b405ec9..7f14600b 100644
--- a/candle-core/src/quantized/k_quants.rs
+++ b/candle-core/src/quantized/k_quants.rs
@@ -1104,6 +1104,9 @@ impl GgmlType for BlockQ4K {
#[allow(unreachable_code)]
fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
+ #[cfg(target_feature = "avx")]
+ return super::avx::vec_dot_q4k_q8k(n, xs, ys);
+
#[cfg(target_feature = "neon")]
return super::neon::vec_dot_q4k_q8k(n, xs, ys);