summary refs log tree commit diff
path: root/candle-core/src
diff options
context:
space:
mode:
author Laurent Mazare <laurent.mazare@gmail.com> 2023-09-30 20:12:41 +0100
committer GitHub <noreply@github.com> 2023-09-30 20:12:41 +0100
commit 4e55aaa51f61a2472c24c52d7b3596f4d0bca4f7 (patch)
tree 8e7076acefdb02d25691a852f296b6120771ff73 /candle-core/src
parent deee7612da7dcda1aa1cfd4237f4858d9f5ed8c7 (diff)
download candle-4e55aaa51f61a2472c24c52d7b3596f4d0bca4f7.tar.gz
candle-4e55aaa51f61a2472c24c52d7b3596f4d0bca4f7.tar.bz2
candle-4e55aaa51f61a2472c24c52d7b3596f4d0bca4f7.zip
Simd128 version of the q2k-q8k vecdot product. (#1011)
* Sketch the simd128 version of q2k vecdot.
* Use a single accumulator.
* Simdify the q2k-q8k vecdot product.
* Cosmetic change.
Diffstat (limited to 'candle-core/src')
-rw-r--r-- candle-core/src/quantized/k_quants.rs 8
-rw-r--r-- candle-core/src/quantized/simd128.rs 112
2 files changed, 75 insertions, 45 deletions
diff --git a/candle-core/src/quantized/k_quants.rs b/candle-core/src/quantized/k_quants.rs
index 9a72d88e..602ea25f 100644
--- a/candle-core/src/quantized/k_quants.rs
+++ b/candle-core/src/quantized/k_quants.rs
@@ -710,18 +710,17 @@ impl GgmlType for BlockQ2K {
let mut isum = 0;
let mut is = 0;
- let mut d;
for _ in 0..(QK_K / 128) {
let mut shift = 0;
for _ in 0..4 {
- d = (sc[is] & 0xF) as i32;
+ let d = (sc[is] & 0xF) as i32;
is += 1;
let mut isuml = 0;
for l in 0..16 {
isuml += q8[l] as i32 * (((q2[l] >> shift) & 3) as i32);
}
isum += d * isuml;
- d = (sc[is] & 0xF) as i32;
+ let d = (sc[is] & 0xF) as i32;
is += 1;
isuml = 0;
for l in 16..32 {
@@ -1086,7 +1085,6 @@ impl GgmlType for BlockQ3K {
let d_all = block.d.to_f32();
let mut m = 1;
let mut is = 0;
- let mut dl;
// Dequantize both 128 long blocks
// 32 qs values per 128 long block
@@ -1097,7 +1095,7 @@ impl GgmlType for BlockQ3K {
for (scale_index, scale_scoped_y) in
shift_scoped_y.chunks_exact_mut(16).enumerate()
{
- dl = d_all * (scales[is] as f32 - 32.0);
+ let dl = d_all * (scales[is] as f32 - 32.0);
for (i, inner_y) in scale_scoped_y.iter_mut().enumerate() {
let new_y = dl
* (((qs[i + 16 * scale_index] >> shift) & 3) as i8
diff --git a/candle-core/src/quantized/simd128.rs b/candle-core/src/quantized/simd128.rs
index 061421c4..bddeda7e 100644
--- a/candle-core/src/quantized/simd128.rs
+++ b/candle-core/src/quantized/simd128.rs
@@ -102,53 +102,85 @@ pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) ->
#[inline(always)]
pub(crate) fn vec_dot_q2k_q8k(n: usize, xs: &[BlockQ2K], ys: &[BlockQ8K]) -> Result<f32> {
if n % QK_K != 0 {
- crate::bail!("vec_dot_q4k_q8k: {n} is not divisible by {QK_K}")
+ crate::bail!("vec_dot_q2k_q8k: {n} is not divisible by {QK_K}")
}
- let mut sumf = 0.0;
- for (x, y) in xs.iter().zip(ys.iter()) {
- let mut q2: &[_] = &x.qs;
- let mut q8: &[_] = &y.qs;
- let sc = &x.scales;
-
- let mut summs = 0;
- for (bsum, scale) in y.bsums.iter().zip(sc) {
- summs += *bsum as i32 * ((scale >> 4) as i32);
- }
+ unsafe {
+ let mut sumf = f32x4_splat(0f32);
+ for (x, y) in xs.iter().zip(ys.iter()) {
+ let mut q2: &[_] = &x.qs;
+ let mut q8: &[_] = &y.qs;
+ let sc = &x.scales;
- let dall = y.d * x.d.to_f32();
- let dmin = y.d * x.dmin.to_f32();
-
- let mut isum = 0;
- let mut is = 0;
- let mut d;
- for _ in 0..(QK_K / 128) {
- let mut shift = 0;
- for _ in 0..4 {
- d = (sc[is] & 0xF) as i32;
- is += 1;
- let mut isuml = 0;
- for l in 0..16 {
- isuml += q8[l] as i32 * (((q2[l] >> shift) & 3) as i32);
- }
- isum += d * isuml;
- d = (sc[is] & 0xF) as i32;
- is += 1;
- isuml = 0;
- for l in 16..32 {
- isuml += q8[l] as i32 * (((q2[l] >> shift) & 3) as i32);
+ let mut summs = i32x4_splat(0);
+ for i in (0..(QK_K / 16)).step_by(4) {
+ let bsums = i32x4_load_extend_i16x4(y.bsums.as_ptr().add(i));
+ let scales = i32x4_shr(
+ i32x4(
+ sc[i] as i32,
+ sc[i + 1] as i32,
+ sc[i + 2] as i32,
+ sc[i + 3] as i32,
+ ),
+ 4,
+ );
+ summs = i32x4_add(summs, i32x4_mul(bsums, scales))
+ }
+ let summs = f32x4_convert_i32x4(summs);
+
+ let dall = y.d * x.d.to_f32();
+ let dmin = y.d * x.dmin.to_f32();
+
+ let mut isum = i32x4_splat(0);
+ let mut is = 0;
+ for _ in 0..(QK_K / 128) {
+ let mut shift = 0;
+ for _ in 0..4 {
+ let d = (sc[is] & 0xF) as i32;
+ is += 1;
+ let mut isuml = i16x8_splat(0);
+ for l in (0..16).step_by(8) {
+ let q8 = i16x8_load_extend_i8x8(q8.as_ptr().add(l));
+ let q2 = i16x8_load_extend_u8x8(q2.as_ptr().add(l));
+ let q2 = v128_and(i16x8_shr(q2, shift), i16x8_splat(3));
+ isuml = i16x8_add(isuml, i16x8_mul(q2, q8))
+ }
+ let dd = i32x4_splat(d);
+ isum = i32x4_add(isum, i32x4_mul(i32x4_extend_low_i16x8(isuml), dd));
+ isum = i32x4_add(isum, i32x4_mul(i32x4_extend_high_i16x8(isuml), dd));
+ let d = (sc[is] & 0xF) as i32;
+ is += 1;
+ let mut isuml = i16x8_splat(0);
+ for l in (16..32).step_by(8) {
+ let q8 = i16x8_load_extend_i8x8(q8.as_ptr().add(l));
+ let q2 = i16x8_load_extend_u8x8(q2.as_ptr().add(l));
+ let q2 = v128_and(i16x8_shr(q2, shift), i16x8_splat(3));
+ isuml = i16x8_add(isuml, i16x8_mul(q2, q8))
+ }
+ let dd = i32x4_splat(d);
+ isum = i32x4_add(isum, i32x4_mul(i32x4_extend_low_i16x8(isuml), dd));
+ isum = i32x4_add(isum, i32x4_mul(i32x4_extend_high_i16x8(isuml), dd));
+ shift += 2;
+ // adjust the indexing
+ q8 = &q8[32..];
}
- isum += d * isuml;
- shift += 2;
// adjust the indexing
- q8 = &q8[32..];
+ q2 = &q2[32..];
}
- // adjust the indexing
- q2 = &q2[32..];
+ let isum = f32x4_convert_i32x4(isum);
+ sumf = f32x4_add(
+ sumf,
+ f32x4_sub(
+ f32x4_mul(isum, f32x4_splat(dall)),
+ f32x4_mul(summs, f32x4_splat(dmin)),
+ ),
+ );
}
- sumf += dall * isum as f32 - dmin * summs as f32;
+ let sumf = f32x4_extract_lane::<0>(sumf)
+ + f32x4_extract_lane::<1>(sumf)
+ + f32x4_extract_lane::<2>(sumf)
+ + f32x4_extract_lane::<3>(sumf);
+ Ok(sumf)
}
-
- Ok(sumf)
}
#[inline(always)]