diff options
-rw-r--r-- | candle-core/src/quantized/avx.rs | 20 | ||||
-rw-r--r-- | candle-core/src/quantized/k_quants.rs | 4 |
2 files changed, 23 insertions, 1 deletions
diff --git a/candle-core/src/quantized/avx.rs b/candle-core/src/quantized/avx.rs
index 9e4ad642..96087feb 100644
--- a/candle-core/src/quantized/avx.rs
+++ b/candle-core/src/quantized/avx.rs
@@ -56,7 +56,6 @@ pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) ->
     }
 
     unsafe {
-        // Generic implementation.
         let mut acc = _mm256_setzero_ps();
         for (x, y) in xs.iter().zip(ys.iter()) {
             let d = _mm256_set1_ps(f16::to_f32(x.d) * f16::to_f32(y.d));
@@ -71,6 +70,25 @@ pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) ->
     }
 }
 
+#[inline(always)]
+pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) -> Result<f32> {
+    let qk = QK8_0;
+    if n % QK8_0 != 0 {
+        crate::bail!("vec_dot_q8_0_q8_0: {n} is not divisible by {qk}")
+    }
+    unsafe {
+        let mut acc = _mm256_setzero_ps();
+        for (x, y) in xs.iter().zip(ys.iter()) {
+            let d = _mm256_set1_ps(f16::to_f32(x.d) * f16::to_f32(y.d));
+            let bx = _mm256_loadu_si256(x.qs.as_ptr() as *const __m256i);
+            let by = _mm256_loadu_si256(y.qs.as_ptr() as *const __m256i);
+            let q = mul_sum_i8_pairs_float(bx, by);
+            acc = _mm256_fmadd_ps(d, q, acc);
+        }
+        Ok(hsum_float_8(acc))
+    }
+}
+
 const K_SHUFFLE: [u8; 128] = [
     0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3,
     4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7,
diff --git a/candle-core/src/quantized/k_quants.rs b/candle-core/src/quantized/k_quants.rs
index 36efe2f2..02022480 100644
--- a/candle-core/src/quantized/k_quants.rs
+++ b/candle-core/src/quantized/k_quants.rs
@@ -421,7 +421,11 @@ impl GgmlType for BlockQ8_0 {
         Ok(())
     }
 
+    #[allow(unreachable_code)]
     fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
+        #[cfg(target_feature = "avx")]
+        return super::avx::vec_dot_q8_0_q8_0(n, xs, ys);
+
         let qk = QK8_0;
         if n % QK8_0 != 0 {
             crate::bail!("vec_dot_q8_0_q8_0: {n} is not divisible by {qk}")