summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-10-02 23:26:34 +0100
committerGitHub <noreply@github.com>2023-10-02 23:26:34 +0100
commit7670fe7d1fa5cacda72c1ab201c5cc34d871ee46 (patch)
tree716e94e0368a9440be005c6da8af50f39c1cd890
parentcddfc3944cd7772230d71ba994c71e2dd5ba119e (diff)
downloadcandle-7670fe7d1fa5cacda72c1ab201c5cc34d871ee46.tar.gz
candle-7670fe7d1fa5cacda72c1ab201c5cc34d871ee46.tar.bz2
candle-7670fe7d1fa5cacda72c1ab201c5cc34d871ee46.zip
neon optimized q8k multiplication. (#1021)
* neon optimized q8k multiplication. * Bugfixes. * simdification.
-rw-r--r--candle-core/src/quantized/k_quants.rs10
-rw-r--r--candle-core/src/quantized/neon.rs29
2 files changed, 36 insertions, 3 deletions
diff --git a/candle-core/src/quantized/k_quants.rs b/candle-core/src/quantized/k_quants.rs
index ac3f7def..80d36555 100644
--- a/candle-core/src/quantized/k_quants.rs
+++ b/candle-core/src/quantized/k_quants.rs
@@ -1756,14 +1756,18 @@ impl GgmlType for BlockQ8K {
const BLCK_SIZE: usize = QK_K;
type VecDotType = BlockQ8K;
+ #[allow(unreachable_code)]
fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
+ #[cfg(target_feature = "neon")]
+ return super::neon::vec_dot_q8k_q8k(n, xs, ys);
+
Self::vec_dot_unopt(n, xs, ys)
}
fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
- let qk = QK8_0;
- if n % QK8_0 != 0 {
- crate::bail!("vec_dot_q8_0_q8_0: {n} is not divisible by {qk}")
+ let qk = QK_K;
+ if n % QK_K != 0 {
+ crate::bail!("vec_dot_q8k_q8k: {n} is not divisible by {qk}")
}
// Generic implementation.
diff --git a/candle-core/src/quantized/neon.rs b/candle-core/src/quantized/neon.rs
index 7f76dadc..fd4c1388 100644
--- a/candle-core/src/quantized/neon.rs
+++ b/candle-core/src/quantized/neon.rs
@@ -149,6 +149,35 @@ pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) ->
}
#[inline(always)]
+pub(crate) fn vec_dot_q8k_q8k(n: usize, xs: &[BlockQ8K], ys: &[BlockQ8K]) -> Result<f32> {
+ let qk = QK_K;
+ if n % QK_K != 0 {
+ crate::bail!("vec_dot_q8k_q8k: {n} is not divisible by {qk}")
+ }
+
+ let mut sumf = 0f32;
+ for (xs, ys) in xs.iter().zip(ys.iter()) {
+ unsafe {
+ let mut sum_i = vdupq_n_s32(0);
+ let scale = xs.d * ys.d;
+ let xs = xs.qs.as_ptr();
+ let ys = ys.qs.as_ptr();
+ for i in (0..QK_K).step_by(16) {
+ let xs = vld1q_s8(xs.add(i));
+ let ys = vld1q_s8(ys.add(i));
+ let xy_lo = vmull_s8(vget_low_s8(xs), vget_low_s8(ys));
+ let xy_up = vmull_s8(vget_high_s8(xs), vget_high_s8(ys));
+
+ let xy = vaddq_s32(vpaddlq_s16(xy_lo), vpaddlq_s16(xy_up));
+ sum_i = vaddq_s32(sum_i, xy)
+ }
+ sumf += vaddvq_s32(sum_i) as f32 * scale
+ }
+ }
+ Ok(sumf)
+}
+
+#[inline(always)]
pub(crate) fn vec_dot_q6k_q8k(n: usize, xs: &[BlockQ6K], ys: &[BlockQ8K]) -> Result<f32> {
if n % QK_K != 0 {
crate::bail!("vec_dot_q6k_q8k: {n} is not divisible by {QK_K}")