summaryrefslogtreecommitdiff
path: root/candle-core/src
diff options
context:
space:
mode:
author: Laurent Mazare <laurent.mazare@gmail.com> 2023-08-25 10:14:49 +0100
committer: GitHub <noreply@github.com> 2023-08-25 10:14:49 +0100
commit: afc10a3232b218dcdb8c3b0989f1066940ea992b (patch)
tree: 0648bca15ca71d1e78c61ae79afa2d039447b247 /candle-core/src
parent: d728e646c20e773498b859fe41f4109f86320ca6 (diff)
download: candle-afc10a3232b218dcdb8c3b0989f1066940ea992b.tar.gz
candle-afc10a3232b218dcdb8c3b0989f1066940ea992b.tar.bz2
candle-afc10a3232b218dcdb8c3b0989f1066940ea992b.zip
AVX version for the q8-0 multiplications. (#598)
Diffstat (limited to 'candle-core/src')
-rw-r--r-- candle-core/src/quantized/avx.rs 20
-rw-r--r-- candle-core/src/quantized/k_quants.rs 4
2 files changed, 23 insertions, 1 deletions
diff --git a/candle-core/src/quantized/avx.rs b/candle-core/src/quantized/avx.rs
index 9e4ad642..96087feb 100644
--- a/candle-core/src/quantized/avx.rs
+++ b/candle-core/src/quantized/avx.rs
@@ -56,7 +56,6 @@ pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) ->
}
unsafe {
- // Generic implementation.
let mut acc = _mm256_setzero_ps();
for (x, y) in xs.iter().zip(ys.iter()) {
let d = _mm256_set1_ps(f16::to_f32(x.d) * f16::to_f32(y.d));
@@ -71,6 +70,25 @@ pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) ->
}
}
+#[inline(always)]
+pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) -> Result<f32> { // SIMD dot product over paired Q8_0 blocks of `xs` and `ys`.
+ let qk = QK8_0; // Local copy so the error message below can interpolate it by name.
+ if n % QK8_0 != 0 { // `n` must be a multiple of QK8_0; it is only used for this check.
+ crate::bail!("vec_dot_q8_0_q8_0: {n} is not divisible by {qk}")
+ }
+ unsafe { // NOTE(review): raw x86 intrinsics and pointer loads; assumes the CPU supports every instruction used — confirm against the cfg gate at the call site.
+ let mut acc = _mm256_setzero_ps(); // Accumulator of 8 f32 lanes, all zero.
+ for (x, y) in xs.iter().zip(ys.iter()) { // `zip` stops at the shorter slice; lengths are not compared with each other or with `n` here.
+ let d = _mm256_set1_ps(f16::to_f32(x.d) * f16::to_f32(y.d)); // Product of both blocks' `d` fields as f32, broadcast to all 8 lanes.
+ let bx = _mm256_loadu_si256(x.qs.as_ptr() as *const __m256i); // Unaligned 256-bit (32-byte) load from `x.qs`; assumes `qs` holds at least 32 bytes — TODO confirm.
+ let by = _mm256_loadu_si256(y.qs.as_ptr() as *const __m256i); // Same unaligned 256-bit load for `y.qs`.
+ let q = mul_sum_i8_pairs_float(bx, by); // Helper defined outside this hunk; presumably multiplies the i8 values pairwise and sums them into f32 lanes — verify.
+ acc = _mm256_fmadd_ps(d, q, acc); // acc = d * q + acc; NOTE(review): this is an FMA intrinsic, not plain AVX.
+ }
+ Ok(hsum_float_8(acc)) // Helper defined outside this hunk; presumably the horizontal sum of the 8 lanes — verify.
+ }
+}
+
const K_SHUFFLE: [u8; 128] = [
0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3,
4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7,
diff --git a/candle-core/src/quantized/k_quants.rs b/candle-core/src/quantized/k_quants.rs
index 36efe2f2..02022480 100644
--- a/candle-core/src/quantized/k_quants.rs
+++ b/candle-core/src/quantized/k_quants.rs
@@ -421,7 +421,11 @@ impl GgmlType for BlockQ8_0 {
Ok(())
}
+ #[allow(unreachable_code)]
fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
+ #[cfg(target_feature = "avx")]
+ return super::avx::vec_dot_q8_0_q8_0(n, xs, ys);
+
let qk = QK8_0;
if n % QK8_0 != 0 {
crate::bail!("vec_dot_q8_0_q8_0: {n} is not divisible by {qk}")