diff options
-rw-r--r-- | candle-core/src/quantized/k_quants.rs | 3 | ||||
-rw-r--r-- | candle-core/src/quantized/simd128.rs | 95 | ||||
-rw-r--r-- | candle-wasm-tests/tests/quantized_tests.rs | 6 |
3 files changed, 103 insertions, 1 deletions
diff --git a/candle-core/src/quantized/k_quants.rs b/candle-core/src/quantized/k_quants.rs
index 064692b7..5b5ea4b0 100644
--- a/candle-core/src/quantized/k_quants.rs
+++ b/candle-core/src/quantized/k_quants.rs
@@ -1132,6 +1132,9 @@ impl GgmlType for BlockQ4K {
         #[cfg(target_feature = "neon")]
         return super::neon::vec_dot_q4k_q8k(n, xs, ys);
 
+        #[cfg(target_feature = "simd128")]
+        return super::simd128::vec_dot_q4k_q8k(n, xs, ys);
+
         if n % QK_K != 0 {
             crate::bail!("vec_dot_q4k_q8k: {n} is not divisible by {QK_K}")
         }
diff --git a/candle-core/src/quantized/simd128.rs b/candle-core/src/quantized/simd128.rs
index c093f189..6b225cce 100644
--- a/candle-core/src/quantized/simd128.rs
+++ b/candle-core/src/quantized/simd128.rs
@@ -1,5 +1,6 @@
-use super::k_quants::{BlockQ4_0, BlockQ8_0, QK8_0};
+use super::k_quants::{BlockQ4K, BlockQ4_0, BlockQ8K, BlockQ8_0, QK8_0, QK_K};
 use crate::Result;
+use byteorder::{ByteOrder, LittleEndian};
 use half::f16;
 
 use core::arch::wasm32::*;
@@ -97,3 +98,95 @@ pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) ->
         Ok(res)
     }
 }
+
+#[inline(always)]
+pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Result<f32> {
+    if n % QK_K != 0 {
+        crate::bail!("vec_dot_q4k_q8k: {n} is not divisible by {QK_K}")
+    }
+
+    const KMASK1: u32 = 0x3f3f3f3f;
+    const KMASK2: u32 = 0x0f0f0f0f;
+    const KMASK3: u32 = 0x03030303;
+
+    let mut utmp: [u32; 4] = [0; 4];
+    let mut scales: [u8; 8] = [0; 8];
+    let mut mins: [u8; 8] = [0; 8];
+
+    let mut aux8: [u8; QK_K] = [0; QK_K];
+    let mut sums = f32x4_splat(0f32);
+    let mut sumf = f32x4_splat(0f32);
+    unsafe {
+        for (y, x) in ys.iter().zip(xs.iter()) {
+            let q4 = &x.qs;
+            let q8 = &y.qs;
+
+            for j in 0..QK_K / 64 {
+                let q4_1 = v128_load(q4.as_ptr().add(32 * j) as *const v128);
+                let q4_2 = v128_load(q4.as_ptr().add(32 * j + 16) as *const v128);
+                v128_store(
+                    aux8.as_mut_ptr().add(64 * j) as *mut v128,
+                    v128_and(q4_1, u8x16_splat(0x0F)),
+                );
+                v128_store(
+                    aux8.as_mut_ptr().add(64 * j + 16) as *mut v128,
+                    v128_and(q4_2, u8x16_splat(0x0F)),
+                );
+                v128_store(
+                    aux8.as_mut_ptr().add(64 * j + 32) as *mut v128,
+                    u8x16_shr(q4_1, 4),
+                );
+                v128_store(
+                    aux8.as_mut_ptr().add(64 * j + 48) as *mut v128,
+                    u8x16_shr(q4_2, 4),
+                );
+            }
+
+            LittleEndian::read_u32_into(&x.scales, &mut utmp[0..3]);
+
+            utmp[3] = ((utmp[2] >> 4) & KMASK2) | (((utmp[1] >> 6) & KMASK3) << 4);
+            let uaux = utmp[1] & KMASK1;
+            utmp[1] = (utmp[2] & KMASK2) | (((utmp[0] >> 6) & KMASK3) << 4);
+            utmp[2] = uaux;
+            utmp[0] &= KMASK1;
+
+            //extract scales and mins
+            LittleEndian::write_u32_into(&utmp[0..2], &mut scales);
+            LittleEndian::write_u32_into(&utmp[2..4], &mut mins);
+
+            let mut sumi = i32x4_splat(0);
+            for j in (0..QK_K / 16).step_by(4) {
+                let bsums = i32x4_load_extend_i16x4(y.bsums.as_ptr().add(j));
+                let (m1, m2) = (mins[j / 2] as i32, mins[j / 2 + 1] as i32);
+                let mins = i32x4(m1, m1, m2, m2);
+                sumi = i32x4_add(sumi, i32x4_mul(bsums, mins));
+            }
+
+            let mut aux32 = i32x4_splat(0i32);
+            for (scale_i, scale) in scales.iter().enumerate() {
+                let scale = i32x4_splat(*scale as i32);
+                for j in 0..4 {
+                    let i = 32 * scale_i + 8 * j;
+                    let q8 = i16x8_load_extend_i8x8(q8.as_ptr().add(i));
+                    let aux8 = i16x8_load_extend_u8x8(aux8.as_ptr().add(i));
+                    let aux16 = i16x8_mul(q8, aux8);
+                    aux32 = i32x4_add(aux32, i32x4_mul(scale, i32x4_extend_low_i16x8(aux16)));
+                    aux32 = i32x4_add(aux32, i32x4_mul(scale, i32x4_extend_high_i16x8(aux16)));
+                }
+            }
+            let aux32 = f32x4_convert_i32x4(aux32);
+            let d = f32x4_splat(x.d.to_f32() * y.d);
+            sums = f32x4_add(sums, f32x4_mul(aux32, d));
+            let dmin = x.dmin.to_f32() * y.d;
+            let dmin = f32x4_splat(dmin);
+            let sumi = f32x4_convert_i32x4(sumi);
+            sumf = f32x4_add(sumf, f32x4_mul(sumi, dmin));
+        }
+        let sums = f32x4_sub(sums, sumf);
+        let sums = f32x4_extract_lane::<0>(sums)
+            + f32x4_extract_lane::<1>(sums)
+            + f32x4_extract_lane::<2>(sums)
+            + f32x4_extract_lane::<3>(sums);
+        Ok(sums)
+    }
+}
diff --git a/candle-wasm-tests/tests/quantized_tests.rs b/candle-wasm-tests/tests/quantized_tests.rs
index 0594a4fa..d16cf497 100644
--- a/candle-wasm-tests/tests/quantized_tests.rs
+++ b/candle-wasm-tests/tests/quantized_tests.rs
@@ -134,6 +134,12 @@ fn quantized_matmul_q40() -> Result<()> {
 }
 
 #[wasm_bindgen_test]
+fn quantized_matmul_q4k() -> Result<()> {
+    ggml_matmul_error_test::<candle::quantized::k_quants::BlockQ4K>()?;
+    Ok(())
+}
+
+#[wasm_bindgen_test]
 fn quantized_matmul_q80() -> Result<()> {
     ggml_matmul_error_test::<candle::quantized::k_quants::BlockQ8_0>()?;
     Ok(())