-rw-r--r-- | candle-core/src/quantized/avx.rs      | 399 |
-rw-r--r-- | candle-core/src/quantized/k_quants.rs |  12 |
2 files changed, 403 insertions, 8 deletions
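
The k_quants.rs side of this change is a compile-time dispatch: each affected vec_dot entry point returns early into the matching new AVX kernel when the avx target feature is enabled, and otherwise falls through to the existing scalar code. A minimal sketch of that shape, based on the k_quants.rs hunks at the end of the diff (the free-function signature and the scalar helper name are illustrative simplifications, not code from the patch):

    // Dispatch shape used by the BlockQ2K / BlockQ3K / BlockQ5K vec_dot impls,
    // written as if inside candle-core/src/quantized where BlockQ2K, BlockQ8K
    // and Result are in scope. `vec_dot_q2k_q8k_scalar` stands in for the
    // existing scalar fallback, which this patch leaves unchanged.
    #[allow(unreachable_code)] // on AVX builds the scalar path below is dead code
    pub fn vec_dot_q2k_q8k(n: usize, xs: &[BlockQ2K], ys: &[BlockQ8K]) -> Result<f32> {
        // Taken when the crate is compiled with the `avx` target feature.
        #[cfg(target_feature = "avx")]
        return super::avx::vec_dot_q2k_q8k(n, xs, ys);

        // Otherwise the pre-existing scalar implementation runs.
        vec_dot_q2k_q8k_scalar(n, xs, ys)
    }

Since cfg(target_feature = "avx") is resolved at compile time, the new kernels are only reached when the crate is built with AVX enabled (for example with a native CPU target); the patch adds no runtime CPU detection.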
diff --git a/candle-core/src/quantized/avx.rs b/candle-core/src/quantized/avx.rs index f906d090..e5fa058d 100644 --- a/candle-core/src/quantized/avx.rs +++ b/candle-core/src/quantized/avx.rs @@ -1,4 +1,6 @@ -use super::k_quants::{BlockQ4K, BlockQ4_0, BlockQ6K, BlockQ8K, BlockQ8_0, QK8_0, QK_K}; +use super::k_quants::{ + BlockQ2K, BlockQ3K, BlockQ4K, BlockQ4_0, BlockQ5K, BlockQ6K, BlockQ8K, BlockQ8_0, QK8_0, QK_K, +}; use crate::Result; use byteorder::{ByteOrder, LittleEndian}; use half::f16; @@ -120,6 +122,18 @@ unsafe fn get_scale_shuffle_k4(i: usize) -> __m256i { } #[inline(always)] +unsafe fn get_scale_shuffle_q3k(i: usize) -> __m256i { + const K_SHUFFLE: [u8; 128] = [ + 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, + 2, 3, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, + 6, 7, 6, 7, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 10, 11, 10, 11, 10, 11, 10, 11, + 10, 11, 10, 11, 10, 11, 10, 11, 12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12, + 13, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, + ]; + _mm256_loadu_si256((K_SHUFFLE.as_ptr() as *const __m256i).add(i)) +} + +#[inline(always)] pub(crate) fn vec_dot_q6k_q8k(n: usize, xs: &[BlockQ6K], ys: &[BlockQ8K]) -> Result<f32> { let qk = QK_K; if n % qk != 0 { @@ -212,15 +226,272 @@ unsafe fn mm256_set_m128i(a: __m128i, b: __m128i) -> __m256i { _mm256_insertf128_si256(_mm256_castsi128_si256(b), a, 1) } +#[cfg_attr(not(debug_assertions), inline(always))] +pub(crate) fn vec_dot_q2k_q8k(n: usize, xs: &[BlockQ2K], ys: &[BlockQ8K]) -> Result<f32> { + if n % QK_K != 0 { + crate::bail!("vec_dot_q2k_q8k: {n} is not divisible by {QK_K}") + } + + unsafe { + let m3 = _mm256_set1_epi8(3); + let m4 = _mm_set1_epi8(0xF); + + let mut acc = _mm256_setzero_ps(); + + for (x, y) in xs.iter().zip(ys.iter()) { + let d = y.d * x.d.to_f32(); + let dmin = -y.d * x.dmin.to_f32(); + + let mut q2 = x.qs.as_ptr(); + let mut q8 = y.qs.as_ptr(); + + let mins_and_scales = _mm_loadu_si128(x.scales.as_ptr() as *const __m128i); + let scales8 = _mm_and_si128(mins_and_scales, m4); + let mins8 = _mm_and_si128(_mm_srli_epi16(mins_and_scales, 4), m4); + let mins = _mm256_cvtepi8_epi16(mins8); + let prod = + _mm256_madd_epi16(mins, _mm256_loadu_si256(y.bsums.as_ptr() as *const __m256i)); + + acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&dmin), _mm256_cvtepi32_ps(prod), acc); + + let all_scales = _mm256_cvtepi8_epi16(scales8); + let l_scales = _mm256_extracti128_si256(all_scales, 0); + let h_scales = _mm256_extracti128_si256(all_scales, 1); + let scales = [ + mm256_set_m128i(l_scales, l_scales), + mm256_set_m128i(h_scales, h_scales), + ]; + + let mut sumi = _mm256_setzero_si256(); + + for scale in scales { + let q2bits = _mm256_loadu_si256(q2 as *const __m256i); + q2 = q2.add(32); + + let q8_0 = _mm256_loadu_si256(q8 as *const __m256i); + q8 = q8.add(32); + let q8_1 = _mm256_loadu_si256(q8 as *const __m256i); + q8 = q8.add(32); + let q8_2 = _mm256_loadu_si256(q8 as *const __m256i); + q8 = q8.add(32); + let q8_3 = _mm256_loadu_si256(q8 as *const __m256i); + q8 = q8.add(32); + + let q2_0 = _mm256_and_si256(q2bits, m3); + let q2_1 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 2), m3); + let q2_2 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 4), m3); + let q2_3 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 6), m3); + + let mut p0 = _mm256_maddubs_epi16(q2_0, q8_0); + let mut p1 = _mm256_maddubs_epi16(q2_1, q8_1); + let mut p2 = _mm256_maddubs_epi16(q2_2, q8_2); + let mut p3 = 
_mm256_maddubs_epi16(q2_3, q8_3); + + p0 = _mm256_madd_epi16(_mm256_shuffle_epi8(scale, get_scale_shuffle_q3k(0)), p0); + p1 = _mm256_madd_epi16(_mm256_shuffle_epi8(scale, get_scale_shuffle_q3k(1)), p1); + p2 = _mm256_madd_epi16(_mm256_shuffle_epi8(scale, get_scale_shuffle_q3k(2)), p2); + p3 = _mm256_madd_epi16(_mm256_shuffle_epi8(scale, get_scale_shuffle_q3k(3)), p3); + + p0 = _mm256_add_epi32(p0, p1); + p2 = _mm256_add_epi32(p2, p3); + + sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p0, p2)); + } + acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc); + } + + Ok(hsum_float_8(acc)) + } +} + +#[cfg_attr(not(debug_assertions), inline(always))] +pub(crate) fn vec_dot_q3k_q8k(n: usize, xs: &[BlockQ3K], ys: &[BlockQ8K]) -> Result<f32> { + if n % QK_K != 0 { + crate::bail!("vec_dot_q3k_q8k: {n} is not divisible by {QK_K}") + } + + const KMASK1: u32 = 0x03030303; + const KMASK2: u32 = 0x0f0f0f0f; + + let mut aux = [0u32; 3]; + + unsafe { + let m3 = _mm256_set1_epi8(3); + let mone = _mm256_set1_epi8(1); + let m32 = _mm_set1_epi8(32); + + let mut acc = _mm256_setzero_ps(); + for (x, y) in xs.iter().zip(ys.iter()) { + let d = y.d * x.d.to_f32(); + + let mut q3 = x.qs.as_ptr(); + let mut q8 = y.qs.as_ptr(); + + LittleEndian::read_u32_into(&x.scales, &mut aux); + let mut scales128 = _mm_set_epi32( + (((aux[1] >> 4) & KMASK2) | (((aux[2] >> 6) & KMASK1) << 4)) as i32, + (((aux[0] >> 4) & KMASK2) | (((aux[2] >> 4) & KMASK1) << 4)) as i32, + ((aux[1] & KMASK2) | (((aux[2] >> 2) & KMASK1) << 4)) as i32, + ((aux[0] & KMASK2) | (((aux[2]) & KMASK1) << 4)) as i32, + ); + scales128 = _mm_sub_epi8(scales128, m32); + let all_scales = _mm256_cvtepi8_epi16(scales128); + let l_scales = _mm256_extracti128_si256(all_scales, 0); + let h_scales = _mm256_extracti128_si256(all_scales, 1); + let scales = [ + mm256_set_m128i(l_scales, l_scales), + mm256_set_m128i(h_scales, h_scales), + ]; + + // high bit + let hbits = _mm256_loadu_si256(x.hmask.as_ptr() as *const __m256i); + + // integer accumulator + let mut sumi = _mm256_setzero_si256(); + + for (j, scale) in scales.iter().enumerate() { + // load low 2 bits + let q3bits = _mm256_loadu_si256(q3 as *const __m256i); + q3 = q3.add(32); + + // prepare low and high bits + //We hardcode the shifts here to avoid loading them into a seperate register + let q3l_0 = _mm256_and_si256(q3bits, m3); + let q3h_0 = if j == 0 { + _mm256_slli_epi16( + _mm256_srli_epi16( + _mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 0)), + 0, + ), + 2, + ) + } else { + _mm256_slli_epi16( + _mm256_srli_epi16( + _mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 4)), + 4, + ), + 2, + ) + }; + + let q3l_1 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 2), m3); + let q3h_1 = if j == 0 { + _mm256_slli_epi16( + _mm256_srli_epi16( + _mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 1)), + 1, + ), + 2, + ) + } else { + _mm256_slli_epi16( + _mm256_srli_epi16( + _mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 5)), + 5, + ), + 2, + ) + }; + + let q3l_2 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 4), m3); + let q3h_2 = if j == 0 { + _mm256_slli_epi16( + _mm256_srli_epi16( + _mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 2)), + 2, + ), + 2, + ) + } else { + _mm256_slli_epi16( + _mm256_srli_epi16( + _mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 6)), + 6, + ), + 2, + ) + }; + + let q3l_3 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 6), m3); + let q3h_3 = if j == 0 { + _mm256_slli_epi16( + _mm256_srli_epi16( + _mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 3)), + 3, + ), + 
2, + ) + } else { + _mm256_slli_epi16( + _mm256_srli_epi16( + _mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 7)), + 7, + ), + 2, + ) + }; + + // load Q8 quants + let q8_0 = _mm256_loadu_si256(q8 as *const __m256i); + q8 = q8.add(32); + let q8_1 = _mm256_loadu_si256(q8 as *const __m256i); + q8 = q8.add(32); + let q8_2 = _mm256_loadu_si256(q8 as *const __m256i); + q8 = q8.add(32); + let q8_3 = _mm256_loadu_si256(q8 as *const __m256i); + q8 = q8.add(32); + + // Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use _mm256_maddubs_epi16, + // and then subtract. The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set, + // and 2 if the high bit was set) + let q8s_0 = _mm256_maddubs_epi16(q3h_0, q8_0); + let q8s_1 = _mm256_maddubs_epi16(q3h_1, q8_1); + let q8s_2 = _mm256_maddubs_epi16(q3h_2, q8_2); + let q8s_3 = _mm256_maddubs_epi16(q3h_3, q8_3); + + let mut p16_0 = _mm256_maddubs_epi16(q3l_0, q8_0); + let mut p16_1 = _mm256_maddubs_epi16(q3l_1, q8_1); + let mut p16_2 = _mm256_maddubs_epi16(q3l_2, q8_2); + let mut p16_3 = _mm256_maddubs_epi16(q3l_3, q8_3); + + p16_0 = _mm256_sub_epi16(p16_0, q8s_0); + p16_1 = _mm256_sub_epi16(p16_1, q8s_1); + p16_2 = _mm256_sub_epi16(p16_2, q8s_2); + p16_3 = _mm256_sub_epi16(p16_3, q8s_3); + + // multiply with scales + p16_0 = + _mm256_madd_epi16(_mm256_shuffle_epi8(*scale, get_scale_shuffle_q3k(0)), p16_0); + p16_1 = + _mm256_madd_epi16(_mm256_shuffle_epi8(*scale, get_scale_shuffle_q3k(1)), p16_1); + p16_2 = + _mm256_madd_epi16(_mm256_shuffle_epi8(*scale, get_scale_shuffle_q3k(2)), p16_2); + p16_3 = + _mm256_madd_epi16(_mm256_shuffle_epi8(*scale, get_scale_shuffle_q3k(3)), p16_3); + + // accumulate + p16_0 = _mm256_add_epi32(p16_0, p16_1); + p16_2 = _mm256_add_epi32(p16_2, p16_3); + sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_2)); + } + + // multiply with block scale and accumulate + acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc); + } + Ok(hsum_float_8(acc)) + } +} + #[inline(always)] pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Result<f32> { if n % QK_K != 0 { crate::bail!("vec_dot_q4k_q8k: {n} is not divisible by {QK_K}") } let mut utmp = [0u32; 4]; - let kmask1: u32 = 0x3f3f3f3f; - let kmask2: u32 = 0x0f0f0f0f; - let kmask3: u32 = 0x03030303; + const KMASK1: u32 = 0x3f3f3f3f; + const KMASK2: u32 = 0x0f0f0f0f; + const KMASK3: u32 = 0x03030303; unsafe { let m4 = _mm256_set1_epi8(0xF); @@ -234,11 +505,11 @@ pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Res LittleEndian::read_u32_into(&x.scales, &mut utmp[0..3]); - utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); - let uaux = utmp[1] & kmask1; - utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[3] = ((utmp[2] >> 4) & KMASK2) | (((utmp[1] >> 6) & KMASK3) << 4); + let uaux = utmp[1] & KMASK1; + utmp[1] = (utmp[2] & KMASK2) | (((utmp[0] >> 6) & KMASK3) << 4); utmp[2] = uaux; - utmp[0] &= kmask1; + utmp[0] &= KMASK1; let mut q4 = x.qs.as_ptr(); let mut q8 = y.qs.as_ptr(); @@ -295,3 +566,115 @@ pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Res Ok(hsum_float_8(acc) + _mm_cvtss_f32(acc_m)) } } + +#[cfg_attr(not(debug_assertions), inline(always))] +pub(crate) fn vec_dot_q5k_q8k(n: usize, xs: &[BlockQ5K], ys: &[BlockQ8K]) -> Result<f32> { + if n % QK_K != 0 { + crate::bail!("vec_dot_q5k_q8k: {n} is not divisible by {QK_K}") + } + let mut utmp = [0u32; 4]; + const 
KMASK1: u32 = 0x3f3f3f3f; + const KMASK2: u32 = 0x0f0f0f0f; + const KMASK3: u32 = 0x03030303; + + unsafe { + let m4 = _mm256_set1_epi8(0xF); + let mzero = _mm_setzero_si128(); + let mone = _mm256_set1_epi8(1); + + let mut acc = _mm256_setzero_ps(); + let mut summs = 0.0; + + for (x, y) in xs.iter().zip(ys.iter()) { + let d = y.d * x.d.to_f32(); + let dmin = -y.d * x.dmin.to_f32(); + + LittleEndian::read_u32_into(&x.scales, &mut utmp[0..3]); + + utmp[3] = ((utmp[2] >> 4) & KMASK2) | (((utmp[1] >> 6) & KMASK3) << 4); + let uaux = utmp[1] & KMASK1; + utmp[1] = (utmp[2] & KMASK2) | (((utmp[0] >> 6) & KMASK3) << 4); + utmp[2] = uaux; + utmp[0] &= KMASK1; + + let mut q5 = x.qs.as_ptr(); + let mut q8 = y.qs.as_ptr(); + + let mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32( + utmp[3] as i32, + utmp[2] as i32, + utmp[1] as i32, + utmp[0] as i32, + )); + + let q8sums = _mm256_loadu_si256(y.bsums.as_ptr() as *const __m256i); + let q8s = _mm_hadd_epi16( + _mm256_extracti128_si256(q8sums, 0), + _mm256_extracti128_si256(q8sums, 1), + ); + let prod = _mm_madd_epi16(_mm256_extracti128_si256(mins_and_scales, 1), q8s); + let hsum = _mm_hadd_epi32(_mm_hadd_epi32(prod, mzero), mzero); + summs += dmin * _mm_extract_epi32(hsum, 0) as f32; + + let sc128 = _mm256_extracti128_si256(mins_and_scales, 0); + let scales = mm256_set_m128i(sc128, sc128); + + let hbits = _mm256_loadu_si256(x.qh.as_ptr() as *const __m256i); + let mut hmask = mone; + + let mut sumi = _mm256_setzero_si256(); + + for j in 0..QK_K / 64 { + let scale_0 = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2 * j)); + let scale_1 = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2 * j + 1)); + + let q5bits = _mm256_loadu_si256(q5 as *const __m256i); + q5 = q5.add(32); + + //Similar to q3k we hardcode the shifts here to avoid loading them into a seperate register + let q5l_0 = _mm256_and_si256(q5bits, m4); + let q5l_0_shift_input = _mm256_and_si256(hbits, hmask); + let q5l_0_right_shift = match j { + 0 => _mm256_srli_epi16(q5l_0_shift_input, 0), + 1 => _mm256_srli_epi16(q5l_0_shift_input, 2), + 2 => _mm256_srli_epi16(q5l_0_shift_input, 4), + 3 => _mm256_srli_epi16(q5l_0_shift_input, 6), + _ => unreachable!(), + }; + let q5h_0 = _mm256_slli_epi16(q5l_0_right_shift, 4); + let q5_0 = _mm256_add_epi8(q5l_0, q5h_0); + hmask = _mm256_slli_epi16(hmask, 1); + + let q5l_1 = _mm256_and_si256(_mm256_srli_epi16(q5bits, 4), m4); + let q5l_1_shift_input = _mm256_and_si256(hbits, hmask); + let q5l_1_right_shift = match j { + 0 => _mm256_srli_epi16(q5l_1_shift_input, 1), + 1 => _mm256_srli_epi16(q5l_1_shift_input, 3), + 2 => _mm256_srli_epi16(q5l_1_shift_input, 5), + 3 => _mm256_srli_epi16(q5l_1_shift_input, 7), + _ => unreachable!(), + }; + + let q5h_1 = _mm256_slli_epi16(q5l_1_right_shift, 4); + let q5_1 = _mm256_add_epi8(q5l_1, q5h_1); + hmask = _mm256_slli_epi16(hmask, 1); + + let q8_0 = _mm256_loadu_si256(q8 as *const __m256i); + q8 = q8.add(32); + let q8_1 = _mm256_loadu_si256(q8 as *const __m256i); + q8 = q8.add(32); + + let mut p16_0 = _mm256_maddubs_epi16(q5_0, q8_0); + let mut p16_1 = _mm256_maddubs_epi16(q5_1, q8_1); + + p16_0 = _mm256_madd_epi16(scale_0, p16_0); + p16_1 = _mm256_madd_epi16(scale_1, p16_1); + + sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1)); + } + let vd = _mm256_set1_ps(d); + acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi), acc); + } + Ok(hsum_float_8(acc) + summs) + } +} diff --git a/candle-core/src/quantized/k_quants.rs b/candle-core/src/quantized/k_quants.rs index 7f14600b..e7404529 100644 --- 
a/candle-core/src/quantized/k_quants.rs +++ b/candle-core/src/quantized/k_quants.rs @@ -671,7 +671,11 @@ impl GgmlType for BlockQ2K { const BLCK_SIZE: usize = QK_K; type VecDotType = BlockQ8K; + #[allow(unreachable_code)] fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> { + #[cfg(target_feature = "avx")] + return super::avx::vec_dot_q2k_q8k(n, xs, ys); + if n % QK_K != 0 { crate::bail!("vec_dot_q2k_q8k: {n} is not divisible by {QK_K}") } @@ -834,7 +838,11 @@ impl GgmlType for BlockQ3K { const BLCK_SIZE: usize = QK_K; type VecDotType = BlockQ8K; + #[allow(unreachable_code)] fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> { + #[cfg(target_feature = "avx")] + return super::avx::vec_dot_q3k_q8k(n, xs, ys); + if n % QK_K != 0 { crate::bail!("vec_dot_q3k_q8k: {n} is not divisible by {QK_K}") } @@ -1288,7 +1296,11 @@ impl GgmlType for BlockQ5K { const BLCK_SIZE: usize = QK_K; type VecDotType = BlockQ8K; + #[allow(unreachable_code)] fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> { + #[cfg(target_feature = "avx")] + return super::avx::vec_dot_q5k_q8k(n, xs, ys); + if n % QK_K != 0 { crate::bail!("vec_dot_q5k_q8k: {n} is not divisible by {QK_K}") } |
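
For reference, the arithmetic that the new vec_dot_q2k_q8k kernel vectorizes can be written in scalar form as below. This mirrors the scalar fallback already present in k_quants.rs (which the #[cfg(target_feature = "avx")] early returns above now bypass); the function name is illustrative and the code is only a reading aid, not part of the patch. It is written as if it sat next to the quantized module, so the crate's block types are in scope.

    use super::k_quants::{BlockQ2K, BlockQ8K};

    // Scalar sketch of the Q2_K x Q8_K block dot product. Layouts assumed,
    // matching what the AVX kernel reads: BlockQ2K holds f16 d/dmin, 16 scale
    // bytes (4-bit scale in the low nibble, 4-bit min in the high nibble) and
    // 64 bytes of 2-bit quants; BlockQ8K holds f32 d, 256 i8 quants and 16 i16
    // per-16-element sums (bsums). Length checks are left to the caller.
    fn vec_dot_q2k_q8k_scalar(xs: &[BlockQ2K], ys: &[BlockQ8K]) -> f32 {
        let mut sumf = 0f32;
        for (x, y) in xs.iter().zip(ys.iter()) {
            let d = y.d * x.d.to_f32();
            let dmin = y.d * x.dmin.to_f32();

            // Mins term: the high nibble of each scale byte weights the
            // precomputed sum of the corresponding 16 q8 values; the whole
            // term is subtracted at the end. The AVX code computes this with
            // _mm256_madd_epi16(mins, bsums) and an fmadd by -dmin.
            let mut summs = 0i32;
            for (&sc, &bsum) in x.scales.iter().zip(y.bsums.iter()) {
                summs += (sc >> 4) as i32 * bsum as i32;
            }

            // Quants term: each 32-byte half of x.qs carries four 2-bit planes
            // (shifts 0, 2, 4, 6) covering 128 values, and every run of 16
            // values uses the low nibble of one scale byte.
            let mut isum = 0i32;
            let mut is = 0usize;
            for (q2, q8_chunk) in x.qs.chunks_exact(32).zip(y.qs.chunks_exact(128)) {
                for (plane, q8) in q8_chunk.chunks_exact(32).enumerate() {
                    let shift = 2 * plane as u32;
                    for (q2_half, q8_half) in q2.chunks_exact(16).zip(q8.chunks_exact(16)) {
                        let scale = (x.scales[is] & 0xF) as i32;
                        is += 1;
                        let mut isuml = 0i32;
                        for (&q2b, &q8v) in q2_half.iter().zip(q8_half.iter()) {
                            isuml += ((q2b >> shift) & 3) as i32 * q8v as i32;
                        }
                        isum += scale * isuml;
                    }
                }
            }
            sumf += d * isum as f32 - dmin * summs as f32;
        }
        sumf
    }

The AVX kernel above performs the same computation 32 lanes at a time: _mm256_maddubs_epi16 forms the q2*q8 products with pairwise 16-bit adds, the _mm256_shuffle_epi8 step with get_scale_shuffle_q3k broadcasts each 16-value scale into the matching lanes before _mm256_madd_epi16, and the mins/bsums correction is folded in with one fused multiply-add per block.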