diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-10-03 15:29:48 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-10-03 15:29:48 +0100 |
commit | 11d3687cc655f8f79d856342a5539a9274e96df4 (patch) | |
tree | 81a54c285c8981400d5a5f9e48aa0cebb6e2b7a8 /candle-core/src/quantized | |
parent | dac73edb3468565fe9817166675db6e422a49767 (diff) | |
download | candle-11d3687cc655f8f79d856342a5539a9274e96df4.tar.gz candle-11d3687cc655f8f79d856342a5539a9274e96df4.tar.bz2 candle-11d3687cc655f8f79d856342a5539a9274e96df4.zip |
Simd128 optimized q8k vecdot. (#1026)
Diffstat (limited to 'candle-core/src/quantized')
-rw-r--r-- | candle-core/src/quantized/k_quants.rs | 3 | ||||
-rw-r--r-- | candle-core/src/quantized/simd128.rs | 30 |
2 files changed, 33 insertions, 0 deletions
diff --git a/candle-core/src/quantized/k_quants.rs b/candle-core/src/quantized/k_quants.rs index 7567c446..b140131e 100644 --- a/candle-core/src/quantized/k_quants.rs +++ b/candle-core/src/quantized/k_quants.rs @@ -1764,6 +1764,9 @@ impl GgmlType for BlockQ8K { #[cfg(target_feature = "neon")] return super::neon::vec_dot_q8k_q8k(n, xs, ys); + #[cfg(target_feature = "simd128")] + return super::simd128::vec_dot_q8k_q8k(n, xs, ys); + Self::vec_dot_unopt(n, xs, ys) } diff --git a/candle-core/src/quantized/simd128.rs b/candle-core/src/quantized/simd128.rs index cc26ac10..687399c2 100644 --- a/candle-core/src/quantized/simd128.rs +++ b/candle-core/src/quantized/simd128.rs @@ -395,3 +395,33 @@ pub(crate) fn vec_dot_q6k_q8k(n: usize, xs: &[BlockQ6K], ys: &[BlockQ8K]) -> Res Ok(sums) } } + +#[inline(always)] +pub(crate) fn vec_dot_q8k_q8k(n: usize, xs: &[BlockQ8K], ys: &[BlockQ8K]) -> Result<f32> { + let qk = QK_K; + if n % QK_K != 0 { + crate::bail!("vec_dot_q8k_q8k: {n} is not divisible by {qk}") + } + + unsafe { + let mut acc = f32x4_splat(0.0f32); + for (xs, ys) in xs.iter().zip(ys.iter()) { + let x_qs = xs.qs.as_ptr(); + let y_qs = ys.qs.as_ptr(); + let mut sumi = i32x4_splat(0); + for j in (0..QK_K).step_by(8) { + let xs = i16x8_load_extend_i8x8(x_qs.add(j)); + let ys = i16x8_load_extend_i8x8(y_qs.add(j)); + let sum_xy = i32x4_dot_i16x8(xs, ys); + sumi = i32x4_add(sumi, sum_xy) + } + let d = f32x4_splat(xs.d * ys.d); + acc = f32x4_add(acc, f32x4_mul(f32x4_convert_i32x4(sumi), d)) + } + let res = f32x4_extract_lane::<0>(acc) + + f32x4_extract_lane::<1>(acc) + + f32x4_extract_lane::<2>(acc) + + f32x4_extract_lane::<3>(acc); + Ok(res) + } +} |