Diffstat (limited to 'candle-core/src')
-rw-r--r--   candle-core/src/quantized/k_quants.rs |   8
-rw-r--r--   candle-core/src/quantized/simd128.rs  | 112
2 files changed, 75 insertions, 45 deletions
diff --git a/candle-core/src/quantized/k_quants.rs b/candle-core/src/quantized/k_quants.rs
index 9a72d88e..602ea25f 100644
--- a/candle-core/src/quantized/k_quants.rs
+++ b/candle-core/src/quantized/k_quants.rs
@@ -710,18 +710,17 @@ impl GgmlType for BlockQ2K {
 
             let mut isum = 0;
             let mut is = 0;
-            let mut d;
             for _ in 0..(QK_K / 128) {
                 let mut shift = 0;
                 for _ in 0..4 {
-                    d = (sc[is] & 0xF) as i32;
+                    let d = (sc[is] & 0xF) as i32;
                     is += 1;
                     let mut isuml = 0;
                     for l in 0..16 {
                         isuml += q8[l] as i32 * (((q2[l] >> shift) & 3) as i32);
                     }
                     isum += d * isuml;
-                    d = (sc[is] & 0xF) as i32;
+                    let d = (sc[is] & 0xF) as i32;
                     is += 1;
                     isuml = 0;
                     for l in 16..32 {
@@ -1086,7 +1085,6 @@ impl GgmlType for BlockQ3K {
             let d_all = block.d.to_f32();
             let mut m = 1;
             let mut is = 0;
-            let mut dl;
 
             // Dequantize both 128 long blocks
             // 32 qs values per 128 long block
@@ -1097,7 +1095,7 @@
                 for (scale_index, scale_scoped_y) in
                     shift_scoped_y.chunks_exact_mut(16).enumerate()
                 {
-                    dl = d_all * (scales[is] as f32 - 32.0);
+                    let dl = d_all * (scales[is] as f32 - 32.0);
                     for (i, inner_y) in scale_scoped_y.iter_mut().enumerate() {
                         let new_y = dl
                             * (((qs[i + 16 * scale_index] >> shift) & 3) as i8
diff --git a/candle-core/src/quantized/simd128.rs b/candle-core/src/quantized/simd128.rs
index 061421c4..bddeda7e 100644
--- a/candle-core/src/quantized/simd128.rs
+++ b/candle-core/src/quantized/simd128.rs
@@ -102,53 +102,85 @@ pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) ->
 #[inline(always)]
 pub(crate) fn vec_dot_q2k_q8k(n: usize, xs: &[BlockQ2K], ys: &[BlockQ8K]) -> Result<f32> {
     if n % QK_K != 0 {
-        crate::bail!("vec_dot_q4k_q8k: {n} is not divisible by {QK_K}")
+        crate::bail!("vec_dot_q2k_q8k: {n} is not divisible by {QK_K}")
     }
-    let mut sumf = 0.0;
-    for (x, y) in xs.iter().zip(ys.iter()) {
-        let mut q2: &[_] = &x.qs;
-        let mut q8: &[_] = &y.qs;
-        let sc = &x.scales;
-
-        let mut summs = 0;
-        for (bsum, scale) in y.bsums.iter().zip(sc) {
-            summs += *bsum as i32 * ((scale >> 4) as i32);
-        }
+    unsafe {
+        let mut sumf = f32x4_splat(0f32);
+        for (x, y) in xs.iter().zip(ys.iter()) {
+            let mut q2: &[_] = &x.qs;
+            let mut q8: &[_] = &y.qs;
+            let sc = &x.scales;
 
-        let dall = y.d * x.d.to_f32();
-        let dmin = y.d * x.dmin.to_f32();
-
-        let mut isum = 0;
-        let mut is = 0;
-        let mut d;
-        for _ in 0..(QK_K / 128) {
-            let mut shift = 0;
-            for _ in 0..4 {
-                d = (sc[is] & 0xF) as i32;
-                is += 1;
-                let mut isuml = 0;
-                for l in 0..16 {
-                    isuml += q8[l] as i32 * (((q2[l] >> shift) & 3) as i32);
-                }
-                isum += d * isuml;
-                d = (sc[is] & 0xF) as i32;
-                is += 1;
-                isuml = 0;
-                for l in 16..32 {
-                    isuml += q8[l] as i32 * (((q2[l] >> shift) & 3) as i32);
+            let mut summs = i32x4_splat(0);
+            for i in (0..(QK_K / 16)).step_by(4) {
+                let bsums = i32x4_load_extend_i16x4(y.bsums.as_ptr().add(i));
+                let scales = i32x4_shr(
+                    i32x4(
+                        sc[i] as i32,
+                        sc[i + 1] as i32,
+                        sc[i + 2] as i32,
+                        sc[i + 3] as i32,
+                    ),
+                    4,
+                );
+                summs = i32x4_add(summs, i32x4_mul(bsums, scales))
+            }
+            let summs = f32x4_convert_i32x4(summs);
+
+            let dall = y.d * x.d.to_f32();
+            let dmin = y.d * x.dmin.to_f32();
+
+            let mut isum = i32x4_splat(0);
+            let mut is = 0;
+            for _ in 0..(QK_K / 128) {
+                let mut shift = 0;
+                for _ in 0..4 {
+                    let d = (sc[is] & 0xF) as i32;
+                    is += 1;
+                    let mut isuml = i16x8_splat(0);
+                    for l in (0..16).step_by(8) {
+                        let q8 = i16x8_load_extend_i8x8(q8.as_ptr().add(l));
+                        let q2 = i16x8_load_extend_u8x8(q2.as_ptr().add(l));
+                        let q2 = v128_and(i16x8_shr(q2, shift), i16x8_splat(3));
+                        isuml = i16x8_add(isuml, i16x8_mul(q2, q8))
+                    }
+                    let dd = i32x4_splat(d);
+                    isum = i32x4_add(isum, i32x4_mul(i32x4_extend_low_i16x8(isuml), dd));
+                    isum = i32x4_add(isum, i32x4_mul(i32x4_extend_high_i16x8(isuml), dd));
+                    let d = (sc[is] & 0xF) as i32;
+                    is += 1;
+                    let mut isuml = i16x8_splat(0);
+                    for l in (16..32).step_by(8) {
+                        let q8 = i16x8_load_extend_i8x8(q8.as_ptr().add(l));
+                        let q2 = i16x8_load_extend_u8x8(q2.as_ptr().add(l));
+                        let q2 = v128_and(i16x8_shr(q2, shift), i16x8_splat(3));
+                        isuml = i16x8_add(isuml, i16x8_mul(q2, q8))
+                    }
+                    let dd = i32x4_splat(d);
+                    isum = i32x4_add(isum, i32x4_mul(i32x4_extend_low_i16x8(isuml), dd));
+                    isum = i32x4_add(isum, i32x4_mul(i32x4_extend_high_i16x8(isuml), dd));
+                    shift += 2;
+                    // adjust the indexing
+                    q8 = &q8[32..];
                 }
-                isum += d * isuml;
-                shift += 2;
                 // adjust the indexing
-                q8 = &q8[32..];
+                q2 = &q2[32..];
             }
-            // adjust the indexing
-            q2 = &q2[32..];
+            let isum = f32x4_convert_i32x4(isum);
+            sumf = f32x4_add(
+                sumf,
+                f32x4_sub(
+                    f32x4_mul(isum, f32x4_splat(dall)),
+                    f32x4_mul(summs, f32x4_splat(dmin)),
+                ),
+            );
         }
-        sumf += dall * isum as f32 - dmin * summs as f32;
+        let sumf = f32x4_extract_lane::<0>(sumf)
+            + f32x4_extract_lane::<1>(sumf)
+            + f32x4_extract_lane::<2>(sumf)
+            + f32x4_extract_lane::<3>(sumf);
+        Ok(sumf)
     }
-
-    Ok(sumf)
 }
 
 #[inline(always)]
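Note on the idiom used above (not part of the patch): the core of the simd128 rewrite is a widening multiply-accumulate. Eight signed and eight unsigned bytes are loaded and extended to i16 lanes, multiplied lane-wise, widened into an i32x4 accumulator, and the four lanes are finally converted to f32 and summed. The sketch below shows that pattern in isolation; the function name, argument layout, and length checks are made up for illustration, and it assumes a wasm32 build with the simd128 target feature enabled, as the module in this commit does.

    // Scalar reference for the same dot product, useful for checking the SIMD path.
    pub fn dot_u8_i8_scalar(q2: &[u8], q8: &[i8]) -> f32 {
        q2.iter()
            .zip(q8)
            .map(|(&a, &b)| a as i32 * b as i32)
            .sum::<i32>() as f32
    }

    // Widening multiply-accumulate sketch using core::arch::wasm32 intrinsics
    // (hypothetical helper, for illustration only).
    #[cfg(target_arch = "wasm32")]
    pub fn dot_u8_i8_simd128(q2: &[u8], q8: &[i8]) -> f32 {
        use core::arch::wasm32::*;
        assert_eq!(q2.len(), q8.len());
        assert_eq!(q2.len() % 8, 0);
        unsafe {
            // i32x4 accumulator, playing the same role as `isum` in the kernel.
            let mut acc = i32x4_splat(0);
            for l in (0..q2.len()).step_by(8) {
                // Widen 8 signed and 8 unsigned bytes to eight i16 lanes each.
                let a = i16x8_load_extend_i8x8(q8.as_ptr().add(l));
                let b = i16x8_load_extend_u8x8(q2.as_ptr().add(l));
                // Lane-wise 16-bit products (u8 * i8 always fits in i16),
                // then widen the low and high halves and accumulate as i32.
                let prod = i16x8_mul(a, b);
                acc = i32x4_add(acc, i32x4_extend_low_i16x8(prod));
                acc = i32x4_add(acc, i32x4_extend_high_i16x8(prod));
            }
            // Convert to f32 lanes and reduce horizontally, as the kernel does for `sumf`.
            let acc = f32x4_convert_i32x4(acc);
            f32x4_extract_lane::<0>(acc)
                + f32x4_extract_lane::<1>(acc)
                + f32x4_extract_lane::<2>(acc)
                + f32x4_extract_lane::<3>(acc)
        }
    }

On a simd128-enabled wasm32 target the two functions should agree, which mirrors the property the vectorized vec_dot_q2k_q8k relies on relative to the scalar reference kept in k_quants.rs.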