use super::k_quants::{ BlockQ2K, BlockQ3K, BlockQ4K, BlockQ4_0, BlockQ5K, BlockQ6K, BlockQ8K, BlockQ8_0, QK8_0, QK_K, }; use crate::Result; use byteorder::{ByteOrder, LittleEndian}; #[allow(unused_imports)] #[cfg(target_arch = "arm")] use core::arch::arm::*; #[allow(unused_imports)] #[cfg(target_arch = "aarch64")] use core::arch::aarch64::*; #[inline(always)] unsafe fn vdotq_s32(a: int8x16_t, b: int8x16_t) -> int32x4_t { // TODO: dotprod let p0 = vmull_s8(vget_low_s8(a), vget_low_s8(b)); let p1 = vmull_s8(vget_high_s8(a), vget_high_s8(b)); vaddq_s32(vpaddlq_s16(p0), vpaddlq_s16(p1)) } #[inline(always)] pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) -> Result { let qk = QK8_0; let nb = n / qk; if n % QK8_0 != 0 { crate::bail!("vec_dot_q4_0_q8_0: {n} is not divisible by {qk}") } unsafe { let mut sumv0 = vdupq_n_f32(0.0f32); for i in 0..nb { let x0 = &xs[i]; let y0 = &ys[i]; let m4b = vdupq_n_u8(0x0F); let s8b = vdupq_n_s8(0x8); let v0_0 = vld1q_u8(x0.qs.as_ptr()); // 4-bit -> 8-bit let v0_0l = vreinterpretq_s8_u8(vandq_u8(v0_0, m4b)); let v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); // sub 8 let v0_0ls = vsubq_s8(v0_0l, s8b); let v0_0hs = vsubq_s8(v0_0h, s8b); // load y let v1_0l = vld1q_s8(y0.qs.as_ptr()); let v1_0h = vld1q_s8(y0.qs.as_ptr().add(16)); let pl0 = vdotq_s32(v0_0ls, v1_0l); let ph0 = vdotq_s32(v0_0hs, v1_0h); sumv0 = vmlaq_n_f32( sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), x0.d.to_f32() * y0.d.to_f32(), ); } Ok(vaddvq_f32(sumv0)) } } #[inline(always)] pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) -> Result { let qk = QK8_0; if n % QK8_0 != 0 { crate::bail!("vec_dot_q8_0_q8_0: {n} is not divisible by {qk}") } let nb = n / QK8_0; unsafe { let mut sumv0 = vdupq_n_f32(0.0f32); for i in 0..nb { let x0 = &xs[i]; let y0 = &ys[i]; let x0_0 = vld1q_s8(x0.qs.as_ptr()); let x0_1 = vld1q_s8(x0.qs.as_ptr().add(16)); // load y let y0_0 = vld1q_s8(y0.qs.as_ptr()); let y0_1 = vld1q_s8(y0.qs.as_ptr().add(16)); let p0 = vdotq_s32(x0_0, y0_0); let p1 = vdotq_s32(x0_1, y0_1); sumv0 = vmlaq_n_f32( sumv0, vcvtq_f32_s32(vaddq_s32(p0, p1)), x0.d.to_f32() * y0.d.to_f32(), ); } Ok(vaddvq_f32(sumv0)) } } #[inline(always)] pub(crate) fn vec_dot_q8k_q8k(n: usize, xs: &[BlockQ8K], ys: &[BlockQ8K]) -> Result { let qk = QK_K; if n % QK_K != 0 { crate::bail!("vec_dot_q8k_q8k: {n} is not divisible by {qk}") } let mut sumf = 0f32; for (xs, ys) in xs.iter().zip(ys.iter()) { unsafe { let mut sum_i = vdupq_n_s32(0); let scale = xs.d * ys.d; let xs = xs.qs.as_ptr(); let ys = ys.qs.as_ptr(); for i in (0..QK_K).step_by(16) { let xs = vld1q_s8(xs.add(i)); let ys = vld1q_s8(ys.add(i)); let xy = vdotq_s32(xs, ys); sum_i = vaddq_s32(sum_i, xy) } sumf += vaddvq_s32(sum_i) as f32 * scale } } Ok(sumf) } #[inline(always)] pub(crate) fn vec_dot_q6k_q8k(n: usize, xs: &[BlockQ6K], ys: &[BlockQ8K]) -> Result { if n % QK_K != 0 { crate::bail!("vec_dot_q6k_q8k: {n} is not divisible by {QK_K}") } let mut sum = 0f32; unsafe { let m4b = vdupq_n_u8(0xF); let mone = vdupq_n_u8(3); for (x, y) in xs.iter().zip(ys.iter()) { let d_all = x.d.to_f32(); let mut q6 = x.ql.as_ptr(); let mut qh = x.qh.as_ptr(); let mut q8 = y.qs.as_ptr(); let mut scale = x.scales.as_ptr(); let q8sums = vld1q_s16_x2(y.bsums.as_ptr()); let scales = vld1q_s8(scale); let q6scales = int16x8x2_t( vmovl_s8(vget_low_s8(scales)), vmovl_s8(vget_high_s8(scales)), ); let prod = vaddq_s32( vaddq_s32( vmull_s16(vget_low_s16(q8sums.0), vget_low_s16(q6scales.0)), vmull_s16(vget_high_s16(q8sums.0), vget_high_s16(q6scales.0)), ), vaddq_s32( vmull_s16(vget_low_s16(q8sums.1), vget_low_s16(q6scales.1)), vmull_s16(vget_high_s16(q8sums.1), vget_high_s16(q6scales.1)), ), ); let isum_mins = vaddvq_s32(prod); let mut isum = 0i32; for _j in 0..QK_K / 128 { let qhbits = vld1q_u8_x2(qh); qh = qh.add(32); let q6bits = vld1q_u8_x4(q6); q6 = q6.add(64); let q8bytes = vld1q_s8_x4(q8); q8 = q8.add(64); let q6h_0 = vshlq_n_u8(vandq_u8(mone, qhbits.0), 4); let q6h_1 = vshlq_n_u8(vandq_u8(mone, qhbits.1), 4); let shifted = vshrq_n_u8(qhbits.0, 2); let q6h_2 = vshlq_n_u8(vandq_u8(mone, shifted), 4); let shifted = vshrq_n_u8(qhbits.1, 2); let q6h_3 = vshlq_n_u8(vandq_u8(mone, shifted), 4); let q6bytes_0 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.0, m4b), q6h_0)); let q6bytes_1 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.1, m4b), q6h_1)); let q6bytes_2 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.2, m4b), q6h_2)); let q6bytes_3 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.3, m4b), q6h_3)); let p0 = vdotq_s32(q6bytes_0, q8bytes.0); let p1 = vdotq_s32(q6bytes_1, q8bytes.1); let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32); isum += vaddvq_s32(p0) * scale0 + vaddvq_s32(p1) * scale1; scale = scale.add(2); let p2 = vdotq_s32(q6bytes_2, q8bytes.2); let p3 = vdotq_s32(q6bytes_3, q8bytes.3); let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32); isum += vaddvq_s32(p2) * scale0 + vaddvq_s32(p3) * scale1; scale = scale.add(2); let q8bytes = vld1q_s8_x4(q8); q8 = q8.add(64); let shifted = vshrq_n_u8(qhbits.0, 4); let q6h_0 = vshlq_n_u8(vandq_u8(mone, shifted), 4); let shifted = vshrq_n_u8(qhbits.1, 4); let q6h_1 = vshlq_n_u8(vandq_u8(mone, shifted), 4); let shifted = vshrq_n_u8(qhbits.0, 6); let q6h_2 = vshlq_n_u8(vandq_u8(mone, shifted), 4); let shifted = vshrq_n_u8(qhbits.1, 6); let q6h_3 = vshlq_n_u8(vandq_u8(mone, shifted), 4); let q6bytes_0 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.0, 4), q6h_0)); let q6bytes_1 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.1, 4), q6h_1)); let q6bytes_2 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.2, 4), q6h_2)); let q6bytes_3 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.3, 4), q6h_3)); let p0 = vdotq_s32(q6bytes_0, q8bytes.0); let p1 = vdotq_s32(q6bytes_1, q8bytes.1); let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32); isum += vaddvq_s32(p0) * scale0 + vaddvq_s32(p1) * scale1; scale = scale.add(2); let p2 = vdotq_s32(q6bytes_2, q8bytes.2); let p3 = vdotq_s32(q6bytes_3, q8bytes.3); let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32); isum += vaddvq_s32(p2) * scale0 + vaddvq_s32(p3) * scale1; scale = scale.add(2); } sum += d_all * y.d * ((isum - 32 * isum_mins) as f32); } } Ok(sum) } #[inline(always)] pub(crate) fn vec_dot_q5k_q8k(n: usize, xs: &[BlockQ5K], ys: &[BlockQ8K]) -> Result { if n % QK_K != 0 { crate::bail!("vec_dot_q5k_q8k: {n} is not divisible by {QK_K}") } let mut sumf = 0f32; let mut utmp = [0u32; 4]; const KMASK1: u32 = 0x3f3f3f3f; const KMASK2: u32 = 0x0f0f0f0f; const KMASK3: u32 = 0x03030303; unsafe { let m4b = vdupq_n_u8(0xF); let mone = vdupq_n_u8(1); let mtwo = vdupq_n_u8(2); for (x, y) in xs.iter().zip(ys.iter()) { let d = y.d * x.d.to_f32(); let dmin = y.d * x.dmin.to_f32(); let q8sums = vpaddq_s16( vld1q_s16(y.bsums.as_ptr()), vld1q_s16(y.bsums.as_ptr().add(8)), ); LittleEndian::read_u32_into(&x.scales, &mut utmp[0..3]); utmp[3] = ((utmp[2] >> 4) & KMASK2) | (((utmp[1] >> 6) & KMASK3) << 4); let uaux = utmp[1] & KMASK1; utmp[1] = (utmp[2] & KMASK2) | (((utmp[0] >> 6) & KMASK3) << 4); utmp[2] = uaux; utmp[0] &= KMASK1; let mins8 = vld1_u8((utmp.as_ptr() as *const u8).add(8)); let mins = vreinterpretq_s16_u16(vmovl_u8(mins8)); let prod = vaddq_s32( vmull_s16(vget_low_s16(q8sums), vget_low_s16(mins)), vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins)), ); let sumi_mins = vaddvq_s32(prod); let mut scales = utmp.as_ptr() as *const u8; let mut q5 = x.qs.as_ptr(); let mut q8 = y.qs.as_ptr(); let mut qhbits = vld1q_u8_x2(x.qh.as_ptr()); let mut sumi = 0i32; for _j in 0..QK_K / 64 { let q5bits = vld1q_u8_x2(q5); q5 = q5.add(32); let q8bytes = vld1q_s8_x4(q8); q8 = q8.add(64); let q5h_0 = vshlq_n_u8(vandq_u8(mone, qhbits.0), 4); let q5h_1 = vshlq_n_u8(vandq_u8(mone, qhbits.1), 4); let q5h_2 = vshlq_n_u8(vandq_u8(mtwo, qhbits.0), 3); let q5h_3 = vshlq_n_u8(vandq_u8(mtwo, qhbits.1), 3); qhbits.0 = vshrq_n_u8(qhbits.0, 2); qhbits.1 = vshrq_n_u8(qhbits.1, 2); let q5bytes_0 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.0, m4b), q5h_0)); let q5bytes_1 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.1, m4b), q5h_1)); let q5bytes_2 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.0, 4), q5h_2)); let q5bytes_3 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.1, 4), q5h_3)); let p0 = vdotq_s32(q5bytes_0, q8bytes.0); let p1 = vdotq_s32(q5bytes_1, q8bytes.1); sumi += vaddvq_s32(vaddq_s32(p0, p1)) * *scales as i32; scales = scales.add(1); let p2 = vdotq_s32(q5bytes_2, q8bytes.2); let p3 = vdotq_s32(q5bytes_3, q8bytes.3); sumi += vaddvq_s32(vaddq_s32(p2, p3)) * *scales as i32; scales = scales.add(1); } sumf += d * sumi as f32 - dmin * sumi_mins as f32; } } Ok(sumf) } #[inline(always)] pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Result { if n % QK_K != 0 { crate::bail!("vec_dot_q4k_q8k: {n} is not divisible by {QK_K}") } let mut sumf = 0f32; let mut utmp = [0u32; 4]; let mut scales = [0u8; 16]; const KMASK1: u32 = 0x3f3f3f3f; const KMASK2: u32 = 0x0f0f0f0f; const KMASK3: u32 = 0x03030303; unsafe { let m4b = vdupq_n_u8(0xF); for (x, y) in xs.iter().zip(ys.iter()) { let d = y.d * x.d.to_f32(); let dmin = y.d * x.dmin.to_f32(); let q8sums = vpaddq_s16( vld1q_s16(y.bsums.as_ptr()), vld1q_s16(y.bsums.as_ptr().add(8)), ); LittleEndian::read_u32_into(&x.scales, &mut utmp[0..3]); let mins8 = vld1_u32( [ utmp[1] & KMASK1, ((utmp[2] >> 4) & KMASK2) | (((utmp[1] >> 6) & KMASK3) << 4), ] .as_ptr(), ); utmp[1] = (utmp[2] & KMASK2) | (((utmp[0] >> 6) & KMASK3) << 4); utmp[0] &= KMASK1; let mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins8))); let prod = vaddq_s32( vmull_s16(vget_low_s16(q8sums), vget_low_s16(mins)), vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins)), ); sumf -= dmin * vaddvq_s32(prod) as f32; LittleEndian::write_u32_into(&utmp, &mut scales); let mut q4 = x.qs.as_ptr(); let mut q8 = y.qs.as_ptr(); let mut sumi1 = 0i32; let mut sumi2 = 0i32; for j in 0..QK_K / 64 { let q4bits = vld1q_u8_x2(q4); q4 = q4.add(32); let q8bytes = vld1q_s8_x2(q8); q8 = q8.add(32); let q4bytes = int8x16x2_t( vreinterpretq_s8_u8(vandq_u8(q4bits.0, m4b)), vreinterpretq_s8_u8(vandq_u8(q4bits.1, m4b)), ); let p0 = vdotq_s32(q4bytes.0, q8bytes.0); let p1 = vdotq_s32(q4bytes.1, q8bytes.1); sumi1 += vaddvq_s32(vaddq_s32(p0, p1)) * scales[2 * j] as i32; let q8bytes = vld1q_s8_x2(q8); q8 = q8.add(32); let q4bytes = int8x16x2_t( vreinterpretq_s8_u8(vshrq_n_u8(q4bits.0, 4)), vreinterpretq_s8_u8(vshrq_n_u8(q4bits.1, 4)), ); let p2 = vdotq_s32(q4bytes.0, q8bytes.0); let p3 = vdotq_s32(q4bytes.1, q8bytes.1); sumi2 += vaddvq_s32(vaddq_s32(p2, p3)) * scales[2 * j + 1] as i32; } sumf += d * (sumi1 + sumi2) as f32; } } Ok(sumf) } #[inline(always)] pub(crate) fn vec_dot_q3k_q8k(n: usize, xs: &[BlockQ3K], ys: &[BlockQ8K]) -> Result { if n % QK_K != 0 { crate::bail!("vec_dot_q3k_q8k: {n} is not divisible by {QK_K}") } let mut sumf = 0f32; let mut utmp = [0u32; 4]; let mut aux = [0u32; 3]; const KMASK1: u32 = 0x03030303; const KMASK2: u32 = 0x0f0f0f0f; unsafe { let m3b = vdupq_n_u8(0x3); let m0 = vdupq_n_u8(1); let m1 = vshlq_n_u8(m0, 1); let m2 = vshlq_n_u8(m0, 2); let m3 = vshlq_n_u8(m0, 3); for (x, y) in xs.iter().zip(ys.iter()) { let d = y.d * x.d.to_f32(); let mut q3 = x.qs.as_ptr(); let qh = x.hmask.as_ptr(); let mut q8 = y.qs.as_ptr(); let mut qhbits = vld1q_u8_x2(qh); let mut isum = 0i32; // Set up scales LittleEndian::read_u32_into(&x.scales, &mut aux); utmp[3] = ((aux[1] >> 4) & KMASK2) | (((aux[2] >> 6) & KMASK1) << 4); utmp[2] = ((aux[0] >> 4) & KMASK2) | (((aux[2] >> 4) & KMASK1) << 4); utmp[1] = (aux[1] & KMASK2) | (((aux[2] >> 2) & KMASK1) << 4); utmp[0] = (aux[0] & KMASK2) | ((aux[2] & KMASK1) << 4); let mut scale = utmp.as_mut_ptr() as *mut i8; for j in 0..16 { *scale.add(j) -= 32i8 } for j in 0..QK_K / 128 { let q3bits = vld1q_u8_x2(q3); q3 = q3.add(32); let q8bytes_1 = vld1q_s8_x4(q8); q8 = q8.add(64); let q8bytes_2 = vld1q_s8_x4(q8); q8 = q8.add(64); let q3h_0 = vshlq_n_u8(vbicq_u8(m0, qhbits.0), 2); let q3h_1 = vshlq_n_u8(vbicq_u8(m0, qhbits.1), 2); let q3h_2 = vshlq_n_u8(vbicq_u8(m1, qhbits.0), 1); let q3h_3 = vshlq_n_u8(vbicq_u8(m1, qhbits.1), 1); let q3bytes_0 = vsubq_s8( vreinterpretq_s8_u8(vandq_u8(q3bits.0, m3b)), vreinterpretq_s8_u8(q3h_0), ); let q3bytes_1 = vsubq_s8( vreinterpretq_s8_u8(vandq_u8(q3bits.1, m3b)), vreinterpretq_s8_u8(q3h_1), ); let q3bytes_2 = vsubq_s8( vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.0, 2), m3b)), vreinterpretq_s8_u8(q3h_2), ); let q3bytes_3 = vsubq_s8( vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.1, 2), m3b)), vreinterpretq_s8_u8(q3h_3), ); let p0 = vdotq_s32(q3bytes_0, q8bytes_1.0); let p1 = vdotq_s32(q3bytes_1, q8bytes_1.1); let p2 = vdotq_s32(q3bytes_2, q8bytes_1.2); let p3 = vdotq_s32(q3bytes_3, q8bytes_1.3); isum += vaddvq_s32(p0) * *scale as i32 + vaddvq_s32(p1) * *scale.add(1) as i32 + vaddvq_s32(p2) * *scale.add(2) as i32 + vaddvq_s32(p3) * *scale.add(3) as i32; scale = scale.add(4); let q3h_0 = vbicq_u8(m2, qhbits.0); let q3h_1 = vbicq_u8(m2, qhbits.1); let q3h_2 = vshrq_n_u8(vbicq_u8(m3, qhbits.0), 1); let q3h_3 = vshrq_n_u8(vbicq_u8(m3, qhbits.1), 1); let q3bytes_0 = vsubq_s8( vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.0, 4), m3b)), vreinterpretq_s8_u8(q3h_0), ); let q3bytes_1 = vsubq_s8( vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.1, 4), m3b)), vreinterpretq_s8_u8(q3h_1), ); let q3bytes_2 = vsubq_s8( vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.0, 6), m3b)), vreinterpretq_s8_u8(q3h_2), ); let q3bytes_3 = vsubq_s8( vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.1, 6), m3b)), vreinterpretq_s8_u8(q3h_3), ); let p0 = vdotq_s32(q3bytes_0, q8bytes_2.0); let p1 = vdotq_s32(q3bytes_1, q8bytes_2.1); let p2 = vdotq_s32(q3bytes_2, q8bytes_2.2); let p3 = vdotq_s32(q3bytes_3, q8bytes_2.3); isum += vaddvq_s32(p0) * *scale as i32 + vaddvq_s32(p1) * *scale.add(1) as i32 + vaddvq_s32(p2) * *scale.add(2) as i32 + vaddvq_s32(p3) * *scale.add(3) as i32; scale = scale.add(4); if j == 0 { qhbits.0 = vshrq_n_u8(qhbits.0, 4); qhbits.1 = vshrq_n_u8(qhbits.1, 4); } } sumf += d * isum as f32; } } Ok(sumf) } #[inline(always)] pub(crate) fn vec_dot_q2k_q8k(n: usize, xs: &[BlockQ2K], ys: &[BlockQ8K]) -> Result { if n % QK_K != 0 { crate::bail!("vec_dot_q2k_q8k: {n} is not divisible by {QK_K}") } let mut sumf = 0f32; let mut aux = [0u8; 16]; unsafe { let m3 = vdupq_n_u8(0x3); let m4 = vdupq_n_u8(0xF); for (x, y) in xs.iter().zip(ys.iter()) { let d = y.d * x.d.to_f32(); let dmin = -y.d * x.dmin.to_f32(); let mut q2 = x.qs.as_ptr(); let mut q8 = y.qs.as_ptr(); let sc = x.scales.as_ptr(); let mins_and_scales = vld1q_u8(sc); let scales = vandq_u8(mins_and_scales, m4); vst1q_u8(aux.as_mut_ptr(), scales); let mins = vshrq_n_u8(mins_and_scales, 4); let q8sums = vld1q_s16_x2(y.bsums.as_ptr()); let mins16 = int16x8x2_t( vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(mins))), vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(mins))), ); let s0 = vaddq_s32( vmull_s16(vget_low_s16(mins16.0), vget_low_s16(q8sums.0)), vmull_s16(vget_high_s16(mins16.0), vget_high_s16(q8sums.0)), ); let s1 = vaddq_s32( vmull_s16(vget_low_s16(mins16.1), vget_low_s16(q8sums.1)), vmull_s16(vget_high_s16(mins16.1), vget_high_s16(q8sums.1)), ); sumf += dmin * vaddvq_s32(vaddq_s32(s0, s1)) as f32; let mut isum = 0i32; let mut is = 0usize; // TODO: dotprod for _j in 0..QK_K / 128 { let q2bits = vld1q_u8_x2(q2); q2 = q2.add(32); let q8bytes = vld1q_s8_x2(q8); q8 = q8.add(32); let mut q2bytes = int8x16x2_t( vreinterpretq_s8_u8(vandq_u8(q2bits.0, m3)), vreinterpretq_s8_u8(vandq_u8(q2bits.1, m3)), ); isum += multiply_accum_with_scale(&aux, is, 0, q2bytes, q8bytes); let q8bytes = vld1q_s8_x2(q8); q8 = q8.add(32); q2bytes.0 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.0, 2), m3)); q2bytes.1 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.1, 2), m3)); isum += multiply_accum_with_scale(&aux, is, 2, q2bytes, q8bytes); let q8bytes = vld1q_s8_x2(q8); q8 = q8.add(32); q2bytes.0 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.0, 4), m3)); q2bytes.1 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.1, 4), m3)); isum += multiply_accum_with_scale(&aux, is, 4, q2bytes, q8bytes); let q8bytes = vld1q_s8_x2(q8); q8 = q8.add(32); q2bytes.0 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.0, 6), m3)); q2bytes.1 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.1, 6), m3)); isum += multiply_accum_with_scale(&aux, is, 6, q2bytes, q8bytes); is += 8; } sumf += d * isum as f32; } } Ok(sumf) } #[inline(always)] unsafe fn multiply_accum_with_scale( aux: &[u8; 16], is: usize, index: usize, q2bytes: int8x16x2_t, q8bytes: int8x16x2_t, ) -> i32 { let p1 = vdotq_s32(q2bytes.0, q8bytes.0); let p2 = vdotq_s32(q2bytes.1, q8bytes.1); vaddvq_s32(p1) * aux[is + index] as i32 + vaddvq_s32(p2) * aux[is + 1 + index] as i32 }