summaryrefslogtreecommitdiff
path: root/candle-core/src/quantized
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-10-03 15:29:48 +0100
committerGitHub <noreply@github.com>2023-10-03 15:29:48 +0100
commit11d3687cc655f8f79d856342a5539a9274e96df4 (patch)
tree81a54c285c8981400d5a5f9e48aa0cebb6e2b7a8 /candle-core/src/quantized
parentdac73edb3468565fe9817166675db6e422a49767 (diff)
downloadcandle-11d3687cc655f8f79d856342a5539a9274e96df4.tar.gz
candle-11d3687cc655f8f79d856342a5539a9274e96df4.tar.bz2
candle-11d3687cc655f8f79d856342a5539a9274e96df4.zip
Simd128 optimized q8k vecdot. (#1026)
Diffstat (limited to 'candle-core/src/quantized')
-rw-r--r--candle-core/src/quantized/k_quants.rs3
-rw-r--r--candle-core/src/quantized/simd128.rs30
2 files changed, 33 insertions, 0 deletions
diff --git a/candle-core/src/quantized/k_quants.rs b/candle-core/src/quantized/k_quants.rs
index 7567c446..b140131e 100644
--- a/candle-core/src/quantized/k_quants.rs
+++ b/candle-core/src/quantized/k_quants.rs
@@ -1764,6 +1764,9 @@ impl GgmlType for BlockQ8K {
#[cfg(target_feature = "neon")]
return super::neon::vec_dot_q8k_q8k(n, xs, ys);
+ #[cfg(target_feature = "simd128")]
+ return super::simd128::vec_dot_q8k_q8k(n, xs, ys);
+
Self::vec_dot_unopt(n, xs, ys)
}
diff --git a/candle-core/src/quantized/simd128.rs b/candle-core/src/quantized/simd128.rs
index cc26ac10..687399c2 100644
--- a/candle-core/src/quantized/simd128.rs
+++ b/candle-core/src/quantized/simd128.rs
@@ -395,3 +395,33 @@ pub(crate) fn vec_dot_q6k_q8k(n: usize, xs: &[BlockQ6K], ys: &[BlockQ8K]) -> Res
Ok(sums)
}
}
+
+#[inline(always)]
+pub(crate) fn vec_dot_q8k_q8k(n: usize, xs: &[BlockQ8K], ys: &[BlockQ8K]) -> Result<f32> {
+ let qk = QK_K;
+ if n % QK_K != 0 {
+ crate::bail!("vec_dot_q8k_q8k: {n} is not divisible by {qk}")
+ }
+
+ unsafe {
+ let mut acc = f32x4_splat(0.0f32);
+ for (xs, ys) in xs.iter().zip(ys.iter()) {
+ let x_qs = xs.qs.as_ptr();
+ let y_qs = ys.qs.as_ptr();
+ let mut sumi = i32x4_splat(0);
+ for j in (0..QK_K).step_by(8) {
+ let xs = i16x8_load_extend_i8x8(x_qs.add(j));
+ let ys = i16x8_load_extend_i8x8(y_qs.add(j));
+ let sum_xy = i32x4_dot_i16x8(xs, ys);
+ sumi = i32x4_add(sumi, sum_xy)
+ }
+ let d = f32x4_splat(xs.d * ys.d);
+ acc = f32x4_add(acc, f32x4_mul(f32x4_convert_i32x4(sumi), d))
+ }
+ let res = f32x4_extract_lane::<0>(acc)
+ + f32x4_extract_lane::<1>(acc)
+ + f32x4_extract_lane::<2>(acc)
+ + f32x4_extract_lane::<3>(acc);
+ Ok(res)
+ }
+}