Diffstat (limited to 'candle-core/src/quantized/simd128.rs')
-rw-r--r--  candle-core/src/quantized/simd128.rs  112
1 file changed, 72 insertions(+), 40 deletions(-)
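
Context for the hunk below: in a BlockQ2K, each byte of `scales` packs two 4-bit fields, and both the old scalar code and the new simd128 code split them the same way. The low nibble (`sc[is] & 0xF`) weights the q2·q8 partial sums, while the high nibble (`scale >> 4`) is the min that multiplies the per-16-element q8 block sums from `bsums`; per super-block the result is accumulated as `dall * isum - dmin * summs`. A minimal scalar sketch of that byte split (the `split_scale_byte` helper is illustrative only, not part of the patch):

fn split_scale_byte(b: u8) -> (i32, i32) {
    let scale = (b & 0xF) as i32; // low nibble: weights the q2·q8 partial sums (isum)
    let min = (b >> 4) as i32;    // high nibble: weights the q8 block sums (bsums -> summs)
    (scale, min)
}
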
diff --git a/candle-core/src/quantized/simd128.rs b/candle-core/src/quantized/simd128.rs
index 061421c4..bddeda7e 100644
--- a/candle-core/src/quantized/simd128.rs
+++ b/candle-core/src/quantized/simd128.rs
@@ -102,53 +102,85 @@ pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) ->
 #[inline(always)]
 pub(crate) fn vec_dot_q2k_q8k(n: usize, xs: &[BlockQ2K], ys: &[BlockQ8K]) -> Result<f32> {
     if n % QK_K != 0 {
-        crate::bail!("vec_dot_q4k_q8k: {n} is not divisible by {QK_K}")
+        crate::bail!("vec_dot_q2k_q8k: {n} is not divisible by {QK_K}")
     }
-    let mut sumf = 0.0;
-    for (x, y) in xs.iter().zip(ys.iter()) {
-        let mut q2: &[_] = &x.qs;
-        let mut q8: &[_] = &y.qs;
-        let sc = &x.scales;
-
-        let mut summs = 0;
-        for (bsum, scale) in y.bsums.iter().zip(sc) {
-            summs += *bsum as i32 * ((scale >> 4) as i32);
-        }
+    unsafe {
+        let mut sumf = f32x4_splat(0f32);
+        for (x, y) in xs.iter().zip(ys.iter()) {
+            let mut q2: &[_] = &x.qs;
+            let mut q8: &[_] = &y.qs;
+            let sc = &x.scales;
 
-        let dall = y.d * x.d.to_f32();
-        let dmin = y.d * x.dmin.to_f32();
-
-        let mut isum = 0;
-        let mut is = 0;
-        let mut d;
-        for _ in 0..(QK_K / 128) {
-            let mut shift = 0;
-            for _ in 0..4 {
-                d = (sc[is] & 0xF) as i32;
-                is += 1;
-                let mut isuml = 0;
-                for l in 0..16 {
-                    isuml += q8[l] as i32 * (((q2[l] >> shift) & 3) as i32);
-                }
-                isum += d * isuml;
-                d = (sc[is] & 0xF) as i32;
-                is += 1;
-                isuml = 0;
-                for l in 16..32 {
-                    isuml += q8[l] as i32 * (((q2[l] >> shift) & 3) as i32);
+            let mut summs = i32x4_splat(0);
+            for i in (0..(QK_K / 16)).step_by(4) {
+                let bsums = i32x4_load_extend_i16x4(y.bsums.as_ptr().add(i));
+                let scales = i32x4_shr(
+                    i32x4(
+                        sc[i] as i32,
+                        sc[i + 1] as i32,
+                        sc[i + 2] as i32,
+                        sc[i + 3] as i32,
+                    ),
+                    4,
+                );
+                summs = i32x4_add(summs, i32x4_mul(bsums, scales))
+            }
+            let summs = f32x4_convert_i32x4(summs);
+
+            let dall = y.d * x.d.to_f32();
+            let dmin = y.d * x.dmin.to_f32();
+
+            let mut isum = i32x4_splat(0);
+            let mut is = 0;
+            for _ in 0..(QK_K / 128) {
+                let mut shift = 0;
+                for _ in 0..4 {
+                    let d = (sc[is] & 0xF) as i32;
+                    is += 1;
+                    let mut isuml = i16x8_splat(0);
+                    for l in (0..16).step_by(8) {
+                        let q8 = i16x8_load_extend_i8x8(q8.as_ptr().add(l));
+                        let q2 = i16x8_load_extend_u8x8(q2.as_ptr().add(l));
+                        let q2 = v128_and(i16x8_shr(q2, shift), i16x8_splat(3));
+                        isuml = i16x8_add(isuml, i16x8_mul(q2, q8))
+                    }
+                    let dd = i32x4_splat(d);
+                    isum = i32x4_add(isum, i32x4_mul(i32x4_extend_low_i16x8(isuml), dd));
+                    isum = i32x4_add(isum, i32x4_mul(i32x4_extend_high_i16x8(isuml), dd));
+                    let d = (sc[is] & 0xF) as i32;
+                    is += 1;
+                    let mut isuml = i16x8_splat(0);
+                    for l in (16..32).step_by(8) {
+                        let q8 = i16x8_load_extend_i8x8(q8.as_ptr().add(l));
+                        let q2 = i16x8_load_extend_u8x8(q2.as_ptr().add(l));
+                        let q2 = v128_and(i16x8_shr(q2, shift), i16x8_splat(3));
+                        isuml = i16x8_add(isuml, i16x8_mul(q2, q8))
+                    }
+                    let dd = i32x4_splat(d);
+                    isum = i32x4_add(isum, i32x4_mul(i32x4_extend_low_i16x8(isuml), dd));
+                    isum = i32x4_add(isum, i32x4_mul(i32x4_extend_high_i16x8(isuml), dd));
+                    shift += 2;
+                    // adjust the indexing
+                    q8 = &q8[32..];
                 }
-                isum += d * isuml;
-                shift += 2;
                 // adjust the indexing
-                q8 = &q8[32..];
+                q2 = &q2[32..];
             }
-            // adjust the indexing
-            q2 = &q2[32..];
+            let isum = f32x4_convert_i32x4(isum);
+            sumf = f32x4_add(
+                sumf,
+                f32x4_sub(
+                    f32x4_mul(isum, f32x4_splat(dall)),
+                    f32x4_mul(summs, f32x4_splat(dmin)),
+                ),
+            );
         }
-        sumf += dall * isum as f32 - dmin * summs as f32;
+        let sumf = f32x4_extract_lane::<0>(sumf)
+            + f32x4_extract_lane::<1>(sumf)
+            + f32x4_extract_lane::<2>(sumf)
+            + f32x4_extract_lane::<3>(sumf);
+        Ok(sumf)
     }
-
-    Ok(sumf)
 }
 
 #[inline(always)]
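
A note on the tail of the new code: the accumulator `sumf` stays in four f32 lanes and is collapsed to a scalar only once, by extracting each lane. A self-contained sketch of that reduction using the same `core::arch::wasm32` intrinsics as the patch (the `horizontal_sum_f32x4` helper and the example values are hypothetical, not part of the diff):

#[cfg(target_arch = "wasm32")]
use core::arch::wasm32::{f32x4, f32x4_extract_lane, v128};

#[cfg(target_arch = "wasm32")]
#[inline]
fn horizontal_sum_f32x4(v: v128) -> f32 {
    // Sum the four f32 lanes, mirroring the four f32x4_extract_lane calls in vec_dot_q2k_q8k.
    f32x4_extract_lane::<0>(v)
        + f32x4_extract_lane::<1>(v)
        + f32x4_extract_lane::<2>(v)
        + f32x4_extract_lane::<3>(v)
}

#[cfg(target_arch = "wasm32")]
fn horizontal_sum_example() {
    let v = f32x4(1.0, 2.0, 3.0, 4.0);
    assert_eq!(horizontal_sum_f32x4(v), 10.0);
}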