diff options
Diffstat (limited to 'candle-core/src')
-rw-r--r-- | candle-core/src/quantized/neon.rs | 26 | ||||
-rw-r--r-- | candle-core/src/quantized/simd128.rs | 4 |
2 files changed, 2 insertions, 28 deletions
diff --git a/candle-core/src/quantized/neon.rs b/candle-core/src/quantized/neon.rs index fd4c1388..51bd3e66 100644 --- a/candle-core/src/quantized/neon.rs +++ b/candle-core/src/quantized/neon.rs @@ -94,28 +94,18 @@ pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) -> crate::bail!("vec_dot_q8_0_q8_0: {n} is not divisible by {qk}") } let nb = n / QK8_0; - if nb % 2 != 0 { - crate::bail!("vec_dot_q8_0_q8_0: {nb} is not even") - } unsafe { let mut sumv0 = vdupq_n_f32(0.0f32); - let mut sumv1 = vdupq_n_f32(0.0f32); - for i in (0..nb).step_by(2) { + for i in 0..nb { let x0 = &xs[i]; - let x1 = &xs[i + 1]; let y0 = &ys[i]; - let y1 = &ys[i + 1]; let x0_0 = vld1q_s8(x0.qs.as_ptr()); let x0_1 = vld1q_s8(x0.qs.as_ptr().add(16)); - let x1_0 = vld1q_s8(x1.qs.as_ptr()); - let x1_1 = vld1q_s8(x1.qs.as_ptr().add(16)); // load y let y0_0 = vld1q_s8(y0.qs.as_ptr()); let y0_1 = vld1q_s8(y0.qs.as_ptr().add(16)); - let y1_0 = vld1q_s8(y1.qs.as_ptr()); - let y1_1 = vld1q_s8(y1.qs.as_ptr().add(16)); // TODO dotprod once this is the intrinsics are. let p0_0 = vmull_s8(vget_low_s8(x0_0), vget_low_s8(y0_0)); @@ -123,28 +113,16 @@ pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) -> let p0_2 = vmull_s8(vget_low_s8(x0_1), vget_low_s8(y0_1)); let p0_3 = vmull_s8(vget_high_s8(x0_1), vget_high_s8(y0_1)); - let p1_0 = vmull_s8(vget_low_s8(x1_0), vget_low_s8(y1_0)); - let p1_1 = vmull_s8(vget_high_s8(x1_0), vget_high_s8(y1_0)); - let p1_2 = vmull_s8(vget_low_s8(x1_1), vget_low_s8(y1_1)); - let p1_3 = vmull_s8(vget_high_s8(x1_1), vget_high_s8(y1_1)); - let p0 = vaddq_s32(vpaddlq_s16(p0_0), vpaddlq_s16(p0_1)); let p1 = vaddq_s32(vpaddlq_s16(p0_2), vpaddlq_s16(p0_3)); - let p2 = vaddq_s32(vpaddlq_s16(p1_0), vpaddlq_s16(p1_1)); - let p3 = vaddq_s32(vpaddlq_s16(p1_2), vpaddlq_s16(p1_3)); sumv0 = vmlaq_n_f32( sumv0, vcvtq_f32_s32(vaddq_s32(p0, p1)), x0.d.to_f32() * y0.d.to_f32(), ); - sumv1 = vmlaq_n_f32( - sumv1, - vcvtq_f32_s32(vaddq_s32(p2, p3)), - x1.d.to_f32() * y1.d.to_f32(), - ); } - Ok(vaddvq_f32(sumv0) + vaddvq_f32(sumv1)) + Ok(vaddvq_f32(sumv0)) } } diff --git a/candle-core/src/quantized/simd128.rs b/candle-core/src/quantized/simd128.rs index 687399c2..f256fdc2 100644 --- a/candle-core/src/quantized/simd128.rs +++ b/candle-core/src/quantized/simd128.rs @@ -61,10 +61,6 @@ pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) -> if n % QK8_0 != 0 { crate::bail!("vec_dot_q8_0_q8_0: {n} is not divisible by {qk}") } - let nb = n / QK8_0; - if nb % 2 != 0 { - crate::bail!("vec_dot_q8_0_q8_0: {nb} is not even") - } unsafe { let mut acc = f32x4_splat(0.0f32); for (x, y) in xs.iter().zip(ys.iter()) { |