summaryrefslogtreecommitdiff
path: root/candle-core/src
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-09-27 20:19:38 +0100
committerGitHub <noreply@github.com>2023-09-27 20:19:38 +0100
commit9cb110c44c10145efefe3555ae8b9f91f4161fe2 (patch)
tree80dfda90beaa12102cbaa105567269c772cbe6b8 /candle-core/src
parent667f01c17323a5c28a9ae12d9f4512c36cc411b9 (diff)
downloadcandle-9cb110c44c10145efefe3555ae8b9f91f4161fe2.tar.gz
candle-9cb110c44c10145efefe3555ae8b9f91f4161fe2.tar.bz2
candle-9cb110c44c10145efefe3555ae8b9f91f4161fe2.zip
Sketch a simd128 optimized q4k vecdot. (#977)
* Sketch a simd128 optimized q4k vecdot. * Simdify. * More quantization optimizations. * Again more simdification. * Simdify the splitting loop.
Diffstat (limited to 'candle-core/src')
-rw-r--r--candle-core/src/quantized/k_quants.rs3
-rw-r--r--candle-core/src/quantized/simd128.rs95
2 files changed, 97 insertions, 1 deletion
diff --git a/candle-core/src/quantized/k_quants.rs b/candle-core/src/quantized/k_quants.rs
index 064692b7..5b5ea4b0 100644
--- a/candle-core/src/quantized/k_quants.rs
+++ b/candle-core/src/quantized/k_quants.rs
@@ -1132,6 +1132,9 @@ impl GgmlType for BlockQ4K {
#[cfg(target_feature = "neon")]
return super::neon::vec_dot_q4k_q8k(n, xs, ys);
+ #[cfg(target_feature = "simd128")]
+ return super::simd128::vec_dot_q4k_q8k(n, xs, ys);
+
if n % QK_K != 0 {
crate::bail!("vec_dot_q4k_q8k: {n} is not divisible by {QK_K}")
}
diff --git a/candle-core/src/quantized/simd128.rs b/candle-core/src/quantized/simd128.rs
index c093f189..6b225cce 100644
--- a/candle-core/src/quantized/simd128.rs
+++ b/candle-core/src/quantized/simd128.rs
@@ -1,5 +1,6 @@
-use super::k_quants::{BlockQ4_0, BlockQ8_0, QK8_0};
+use super::k_quants::{BlockQ4K, BlockQ4_0, BlockQ8K, BlockQ8_0, QK8_0, QK_K};
use crate::Result;
+use byteorder::{ByteOrder, LittleEndian};
use half::f16;
use core::arch::wasm32::*;
@@ -97,3 +98,95 @@ pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) ->
Ok(res)
}
}
+
+#[inline(always)]
+pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Result<f32> {
+ if n % QK_K != 0 {
+ crate::bail!("vec_dot_q4k_q8k: {n} is not divisible by {QK_K}")
+ }
+
+ const KMASK1: u32 = 0x3f3f3f3f;
+ const KMASK2: u32 = 0x0f0f0f0f;
+ const KMASK3: u32 = 0x03030303;
+
+ let mut utmp: [u32; 4] = [0; 4];
+ let mut scales: [u8; 8] = [0; 8];
+ let mut mins: [u8; 8] = [0; 8];
+
+ let mut aux8: [u8; QK_K] = [0; QK_K];
+ let mut sums = f32x4_splat(0f32);
+ let mut sumf = f32x4_splat(0f32);
+ unsafe {
+ for (y, x) in ys.iter().zip(xs.iter()) {
+ let q4 = &x.qs;
+ let q8 = &y.qs;
+
+ for j in 0..QK_K / 64 {
+ let q4_1 = v128_load(q4.as_ptr().add(32 * j) as *const v128);
+ let q4_2 = v128_load(q4.as_ptr().add(32 * j + 16) as *const v128);
+ v128_store(
+ aux8.as_mut_ptr().add(64 * j) as *mut v128,
+ v128_and(q4_1, u8x16_splat(0x0F)),
+ );
+ v128_store(
+ aux8.as_mut_ptr().add(64 * j + 16) as *mut v128,
+ v128_and(q4_2, u8x16_splat(0x0F)),
+ );
+ v128_store(
+ aux8.as_mut_ptr().add(64 * j + 32) as *mut v128,
+ u8x16_shr(q4_1, 4),
+ );
+ v128_store(
+ aux8.as_mut_ptr().add(64 * j + 48) as *mut v128,
+ u8x16_shr(q4_2, 4),
+ );
+ }
+
+ LittleEndian::read_u32_into(&x.scales, &mut utmp[0..3]);
+
+ utmp[3] = ((utmp[2] >> 4) & KMASK2) | (((utmp[1] >> 6) & KMASK3) << 4);
+ let uaux = utmp[1] & KMASK1;
+ utmp[1] = (utmp[2] & KMASK2) | (((utmp[0] >> 6) & KMASK3) << 4);
+ utmp[2] = uaux;
+ utmp[0] &= KMASK1;
+
+ //extract scales and mins
+ LittleEndian::write_u32_into(&utmp[0..2], &mut scales);
+ LittleEndian::write_u32_into(&utmp[2..4], &mut mins);
+
+ let mut sumi = i32x4_splat(0);
+ for j in (0..QK_K / 16).step_by(4) {
+ let bsums = i32x4_load_extend_i16x4(y.bsums.as_ptr().add(j));
+ let (m1, m2) = (mins[j / 2] as i32, mins[j / 2 + 1] as i32);
+ let mins = i32x4(m1, m1, m2, m2);
+ sumi = i32x4_add(sumi, i32x4_mul(bsums, mins));
+ }
+
+ let mut aux32 = i32x4_splat(0i32);
+ for (scale_i, scale) in scales.iter().enumerate() {
+ let scale = i32x4_splat(*scale as i32);
+ for j in 0..4 {
+ let i = 32 * scale_i + 8 * j;
+ let q8 = i16x8_load_extend_i8x8(q8.as_ptr().add(i));
+ let aux8 = i16x8_load_extend_u8x8(aux8.as_ptr().add(i));
+ let aux16 = i16x8_mul(q8, aux8);
+ aux32 = i32x4_add(aux32, i32x4_mul(scale, i32x4_extend_low_i16x8(aux16)));
+ aux32 = i32x4_add(aux32, i32x4_mul(scale, i32x4_extend_high_i16x8(aux16)));
+ }
+ }
+ let aux32 = f32x4_convert_i32x4(aux32);
+ let d = f32x4_splat(x.d.to_f32() * y.d);
+ sums = f32x4_add(sums, f32x4_mul(aux32, d));
+ let dmin = x.dmin.to_f32() * y.d;
+ let dmin = f32x4_splat(dmin);
+ let sumi = f32x4_convert_i32x4(sumi);
+ sumf = f32x4_add(sumf, f32x4_mul(sumi, dmin));
+ }
+ let sums = f32x4_sub(sums, sumf);
+ let sums = f32x4_extract_lane::<0>(sums)
+ + f32x4_extract_lane::<1>(sums)
+ + f32x4_extract_lane::<2>(sums)
+ + f32x4_extract_lane::<3>(sums);
+ Ok(sums)
+ }
+}