diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-10-02 23:26:34 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-10-02 23:26:34 +0100 |
commit | 7670fe7d1fa5cacda72c1ab201c5cc34d871ee46 (patch) | |
tree | 716e94e0368a9440be005c6da8af50f39c1cd890 | |
parent | cddfc3944cd7772230d71ba994c71e2dd5ba119e (diff) | |
download | candle-7670fe7d1fa5cacda72c1ab201c5cc34d871ee46.tar.gz candle-7670fe7d1fa5cacda72c1ab201c5cc34d871ee46.tar.bz2 candle-7670fe7d1fa5cacda72c1ab201c5cc34d871ee46.zip |
neon optimized q8k multiplication. (#1021)
* neon optimized q8k multiplication.
* Bugfixes.
* simdification.
-rw-r--r-- | candle-core/src/quantized/k_quants.rs | 10 | ||||
-rw-r--r-- | candle-core/src/quantized/neon.rs | 29 |
2 files changed, 36 insertions, 3 deletions
diff --git a/candle-core/src/quantized/k_quants.rs b/candle-core/src/quantized/k_quants.rs index ac3f7def..80d36555 100644 --- a/candle-core/src/quantized/k_quants.rs +++ b/candle-core/src/quantized/k_quants.rs @@ -1756,14 +1756,18 @@ impl GgmlType for BlockQ8K { const BLCK_SIZE: usize = QK_K; type VecDotType = BlockQ8K; + #[allow(unreachable_code)] fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> { + #[cfg(target_feature = "neon")] + return super::neon::vec_dot_q8k_q8k(n, xs, ys); + Self::vec_dot_unopt(n, xs, ys) } fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> { - let qk = QK8_0; - if n % QK8_0 != 0 { - crate::bail!("vec_dot_q8_0_q8_0: {n} is not divisible by {qk}") + let qk = QK_K; + if n % QK_K != 0 { + crate::bail!("vec_dot_q8k_q8k: {n} is not divisible by {qk}") } // Generic implementation. diff --git a/candle-core/src/quantized/neon.rs b/candle-core/src/quantized/neon.rs index 7f76dadc..fd4c1388 100644 --- a/candle-core/src/quantized/neon.rs +++ b/candle-core/src/quantized/neon.rs @@ -149,6 +149,35 @@ pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) -> } #[inline(always)] +pub(crate) fn vec_dot_q8k_q8k(n: usize, xs: &[BlockQ8K], ys: &[BlockQ8K]) -> Result<f32> { + let qk = QK_K; + if n % QK_K != 0 { + crate::bail!("vec_dot_q8k_q8k: {n} is not divisible by {qk}") + } + + let mut sumf = 0f32; + for (xs, ys) in xs.iter().zip(ys.iter()) { + unsafe { + let mut sum_i = vdupq_n_s32(0); + let scale = xs.d * ys.d; + let xs = xs.qs.as_ptr(); + let ys = ys.qs.as_ptr(); + for i in (0..QK_K).step_by(16) { + let xs = vld1q_s8(xs.add(i)); + let ys = vld1q_s8(ys.add(i)); + let xy_lo = vmull_s8(vget_low_s8(xs), vget_low_s8(ys)); + let xy_up = vmull_s8(vget_high_s8(xs), vget_high_s8(ys)); + + let xy = vaddq_s32(vpaddlq_s16(xy_lo), vpaddlq_s16(xy_up)); + sum_i = vaddq_s32(sum_i, xy) + } + sumf += vaddvq_s32(sum_i) as f32 * scale + } + } + Ok(sumf) +} + +#[inline(always)] pub(crate) fn vec_dot_q6k_q8k(n: usize, xs: &[BlockQ6K], ys: &[BlockQ8K]) -> Result<f32> { if n % QK_K != 0 { crate::bail!("vec_dot_q6k_q8k: {n} is not divisible by {QK_K}") |