summaryrefslogtreecommitdiff
path: root/candle-core/src
diff options
context:
space:
mode:
Diffstat (limited to 'candle-core/src')
-rw-r--r--candle-core/src/quantized/neon.rs26
-rw-r--r--candle-core/src/quantized/simd128.rs4
2 files changed, 2 insertions, 28 deletions
diff --git a/candle-core/src/quantized/neon.rs b/candle-core/src/quantized/neon.rs
index fd4c1388..51bd3e66 100644
--- a/candle-core/src/quantized/neon.rs
+++ b/candle-core/src/quantized/neon.rs
@@ -94,28 +94,18 @@ pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) ->
crate::bail!("vec_dot_q8_0_q8_0: {n} is not divisible by {qk}")
}
let nb = n / QK8_0;
- if nb % 2 != 0 {
- crate::bail!("vec_dot_q8_0_q8_0: {nb} is not even")
- }
unsafe {
let mut sumv0 = vdupq_n_f32(0.0f32);
- let mut sumv1 = vdupq_n_f32(0.0f32);
- for i in (0..nb).step_by(2) {
+ for i in 0..nb {
let x0 = &xs[i];
- let x1 = &xs[i + 1];
let y0 = &ys[i];
- let y1 = &ys[i + 1];
let x0_0 = vld1q_s8(x0.qs.as_ptr());
let x0_1 = vld1q_s8(x0.qs.as_ptr().add(16));
- let x1_0 = vld1q_s8(x1.qs.as_ptr());
- let x1_1 = vld1q_s8(x1.qs.as_ptr().add(16));
// load y
let y0_0 = vld1q_s8(y0.qs.as_ptr());
let y0_1 = vld1q_s8(y0.qs.as_ptr().add(16));
- let y1_0 = vld1q_s8(y1.qs.as_ptr());
- let y1_1 = vld1q_s8(y1.qs.as_ptr().add(16));
// TODO dotprod once this is the intrinsics are.
let p0_0 = vmull_s8(vget_low_s8(x0_0), vget_low_s8(y0_0));
@@ -123,28 +113,16 @@ pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) ->
let p0_2 = vmull_s8(vget_low_s8(x0_1), vget_low_s8(y0_1));
let p0_3 = vmull_s8(vget_high_s8(x0_1), vget_high_s8(y0_1));
- let p1_0 = vmull_s8(vget_low_s8(x1_0), vget_low_s8(y1_0));
- let p1_1 = vmull_s8(vget_high_s8(x1_0), vget_high_s8(y1_0));
- let p1_2 = vmull_s8(vget_low_s8(x1_1), vget_low_s8(y1_1));
- let p1_3 = vmull_s8(vget_high_s8(x1_1), vget_high_s8(y1_1));
-
let p0 = vaddq_s32(vpaddlq_s16(p0_0), vpaddlq_s16(p0_1));
let p1 = vaddq_s32(vpaddlq_s16(p0_2), vpaddlq_s16(p0_3));
- let p2 = vaddq_s32(vpaddlq_s16(p1_0), vpaddlq_s16(p1_1));
- let p3 = vaddq_s32(vpaddlq_s16(p1_2), vpaddlq_s16(p1_3));
sumv0 = vmlaq_n_f32(
sumv0,
vcvtq_f32_s32(vaddq_s32(p0, p1)),
x0.d.to_f32() * y0.d.to_f32(),
);
- sumv1 = vmlaq_n_f32(
- sumv1,
- vcvtq_f32_s32(vaddq_s32(p2, p3)),
- x1.d.to_f32() * y1.d.to_f32(),
- );
}
- Ok(vaddvq_f32(sumv0) + vaddvq_f32(sumv1))
+ Ok(vaddvq_f32(sumv0))
}
}
diff --git a/candle-core/src/quantized/simd128.rs b/candle-core/src/quantized/simd128.rs
index 687399c2..f256fdc2 100644
--- a/candle-core/src/quantized/simd128.rs
+++ b/candle-core/src/quantized/simd128.rs
@@ -61,10 +61,6 @@ pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) ->
if n % QK8_0 != 0 {
crate::bail!("vec_dot_q8_0_q8_0: {n} is not divisible by {qk}")
}
- let nb = n / QK8_0;
- if nb % 2 != 0 {
- crate::bail!("vec_dot_q8_0_q8_0: {nb} is not even")
- }
unsafe {
let mut acc = f32x4_splat(0.0f32);
for (x, y) in xs.iter().zip(ys.iter()) {