diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-09-27 20:19:38 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-09-27 20:19:38 +0100 |
commit | 9cb110c44c10145efefe3555ae8b9f91f4161fe2 (patch) | |
tree | 80dfda90beaa12102cbaa105567269c772cbe6b8 /candle-core/src | |
parent | 667f01c17323a5c28a9ae12d9f4512c36cc411b9 (diff) | |
download | candle-9cb110c44c10145efefe3555ae8b9f91f4161fe2.tar.gz candle-9cb110c44c10145efefe3555ae8b9f91f4161fe2.tar.bz2 candle-9cb110c44c10145efefe3555ae8b9f91f4161fe2.zip |
Sketch a simd128 optimized q4k vecdot. (#977)
* Sketch a simd128 optimized q4k vecdot.
* Simdify.
* More quantization optimizations.
* Again more simdification.
* Simdify the splitting loop.
Diffstat (limited to 'candle-core/src')
-rw-r--r-- | candle-core/src/quantized/k_quants.rs | 3 | ||||
-rw-r--r-- | candle-core/src/quantized/simd128.rs | 95 |
2 files changed, 97 insertions, 1 deletions
diff --git a/candle-core/src/quantized/k_quants.rs b/candle-core/src/quantized/k_quants.rs index 064692b7..5b5ea4b0 100644 --- a/candle-core/src/quantized/k_quants.rs +++ b/candle-core/src/quantized/k_quants.rs @@ -1132,6 +1132,9 @@ impl GgmlType for BlockQ4K { #[cfg(target_feature = "neon")] return super::neon::vec_dot_q4k_q8k(n, xs, ys); + #[cfg(target_feature = "simd128")] + return super::simd128::vec_dot_q4k_q8k(n, xs, ys); + if n % QK_K != 0 { crate::bail!("vec_dot_q4k_q8k: {n} is not divisible by {QK_K}") } diff --git a/candle-core/src/quantized/simd128.rs b/candle-core/src/quantized/simd128.rs index c093f189..6b225cce 100644 --- a/candle-core/src/quantized/simd128.rs +++ b/candle-core/src/quantized/simd128.rs @@ -1,5 +1,6 @@ -use super::k_quants::{BlockQ4_0, BlockQ8_0, QK8_0}; +use super::k_quants::{BlockQ4K, BlockQ4_0, BlockQ8K, BlockQ8_0, QK8_0, QK_K}; use crate::Result; +use byteorder::{ByteOrder, LittleEndian}; use half::f16; use core::arch::wasm32::*; @@ -97,3 +98,95 @@ pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) -> Ok(res) } } + +#[inline(always)] +pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Result<f32> { + if n % QK_K != 0 { + crate::bail!("vec_dot_q4k_q8k: {n} is not divisible by {QK_K}") + } + + const KMASK1: u32 = 0x3f3f3f3f; + const KMASK2: u32 = 0x0f0f0f0f; + const KMASK3: u32 = 0x03030303; + + let mut utmp: [u32; 4] = [0; 4]; + let mut scales: [u8; 8] = [0; 8]; + let mut mins: [u8; 8] = [0; 8]; + + let mut aux8: [u8; QK_K] = [0; QK_K]; + let mut sums = f32x4_splat(0f32); + let mut sumf = f32x4_splat(0f32); + unsafe { + for (y, x) in ys.iter().zip(xs.iter()) { + let q4 = &x.qs; + let q8 = &y.qs; + + for j in 0..QK_K / 64 { + let q4_1 = v128_load(q4.as_ptr().add(32 * j) as *const v128); + let q4_2 = v128_load(q4.as_ptr().add(32 * j + 16) as *const v128); + v128_store( + aux8.as_mut_ptr().add(64 * j) as *mut v128, + v128_and(q4_1, u8x16_splat(0x0F)), + ); + v128_store( + aux8.as_mut_ptr().add(64 * j + 16) as *mut v128, + v128_and(q4_2, u8x16_splat(0x0F)), + ); + v128_store( + aux8.as_mut_ptr().add(64 * j + 32) as *mut v128, + u8x16_shr(q4_1, 4), + ); + v128_store( + aux8.as_mut_ptr().add(64 * j + 48) as *mut v128, + u8x16_shr(q4_2, 4), + ); + } + + LittleEndian::read_u32_into(&x.scales, &mut utmp[0..3]); + + utmp[3] = ((utmp[2] >> 4) & KMASK2) | (((utmp[1] >> 6) & KMASK3) << 4); + let uaux = utmp[1] & KMASK1; + utmp[1] = (utmp[2] & KMASK2) | (((utmp[0] >> 6) & KMASK3) << 4); + utmp[2] = uaux; + utmp[0] &= KMASK1; + + //extract scales and mins + LittleEndian::write_u32_into(&utmp[0..2], &mut scales); + LittleEndian::write_u32_into(&utmp[2..4], &mut mins); + + let mut sumi = i32x4_splat(0); + for j in (0..QK_K / 16).step_by(4) { + let bsums = i32x4_load_extend_i16x4(y.bsums.as_ptr().add(j)); + let (m1, m2) = (mins[j / 2] as i32, mins[j / 2 + 1] as i32); + let mins = i32x4(m1, m1, m2, m2); + sumi = i32x4_add(sumi, i32x4_mul(bsums, mins)); + } + + let mut aux32 = i32x4_splat(0i32); + for (scale_i, scale) in scales.iter().enumerate() { + let scale = i32x4_splat(*scale as i32); + for j in 0..4 { + let i = 32 * scale_i + 8 * j; + let q8 = i16x8_load_extend_i8x8(q8.as_ptr().add(i)); + let aux8 = i16x8_load_extend_u8x8(aux8.as_ptr().add(i)); + let aux16 = i16x8_mul(q8, aux8); + aux32 = i32x4_add(aux32, i32x4_mul(scale, i32x4_extend_low_i16x8(aux16))); + aux32 = i32x4_add(aux32, i32x4_mul(scale, i32x4_extend_high_i16x8(aux16))); + } + } + let aux32 = f32x4_convert_i32x4(aux32); + let d = f32x4_splat(x.d.to_f32() * y.d); + sums = f32x4_add(sums, f32x4_mul(aux32, d)); + let dmin = x.dmin.to_f32() * y.d; + let dmin = f32x4_splat(dmin); + let sumi = f32x4_convert_i32x4(sumi); + sumf = f32x4_add(sumf, f32x4_mul(sumi, dmin)); + } + let sums = f32x4_sub(sums, sumf); + let sums = f32x4_extract_lane::<0>(sums) + + f32x4_extract_lane::<1>(sums) + + f32x4_extract_lane::<2>(sums) + + f32x4_extract_lane::<3>(sums); + Ok(sums) + } +} |