Diffstat (limited to 'candle-core/src')
-rw-r--r--   candle-core/src/metal_backend.rs       |  58
-rw-r--r--   candle-core/src/pickle.rs              |  14
-rw-r--r--   candle-core/src/quantized/k_quants.rs  |   8
-rw-r--r--   candle-core/src/quantized/neon.rs      | 208
4 files changed, 127 insertions, 161 deletions
diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs
index 24beeb7a..48250233 100644
--- a/candle-core/src/metal_backend.rs
+++ b/candle-core/src/metal_backend.rs
@@ -592,14 +592,26 @@ impl BackendStorage for MetalStorage {
             (DType::U32, DType::F32) => "cast_u32_f32",
             (DType::U32, DType::U8) => "cast_u32_u8",
             (DType::U32, DType::I64) => "cast_u32_i64",
+            (DType::U32, DType::BF16) => "cast_u32_bf16",
+
             (DType::U8, DType::U32) => "cast_u8_u32",
             (DType::U8, DType::F32) => "cast_u8_f32",
             (DType::U8, DType::I64) => "cast_u8_i64",
+            (DType::U8, DType::BF16) => "cast_u8_bf16",
+
             (DType::F32, DType::F16) => "cast_f32_f16",
-            (DType::F16, DType::F32) => "cast_f16_f32",
-            (DType::I64, DType::F32) => "cast_i64_f32",
             (DType::F32, DType::BF16) => "cast_f32_bf16",
+
+            (DType::I64, DType::F32) => "cast_i64_f32",
+
+            (DType::F16, DType::BF16) => "cast_f16_bf16",
+            (DType::F16, DType::F32) => "cast_f16_f32",
+
+            (DType::BF16, DType::U8) => "cast_bf16_u8",
+            (DType::BF16, DType::U32) => "cast_bf16_u32",
+            (DType::BF16, DType::F16) => "cast_bf16_f16",
             (DType::BF16, DType::F32) => "cast_bf16_f32",
+
             (left, right) => {
                 crate::bail!("Metal contiguous to_dtype {left:?} {right:?} not implemented")
             }
@@ -677,6 +689,7 @@ impl BackendStorage for MetalStorage {
                 ("uround", DType::F32) => contiguous::round::FLOAT,
                 ("urecip", DType::F32) => contiguous::recip::FLOAT,
                 ("utanh", DType::F32) => contiguous::tanh::FLOAT,
+                ("urelu", DType::F32) => contiguous::relu::FLOAT,
                 ("ucos", DType::F16) => contiguous::cos::HALF,
                 ("usin", DType::F16) => contiguous::sin::HALF,
                 ("usqr", DType::F16) => contiguous::sqr::HALF,
@@ -693,6 +706,7 @@ impl BackendStorage for MetalStorage {
                 ("uround", DType::F16) => contiguous::round::HALF,
                 ("urecip", DType::F16) => contiguous::recip::HALF,
                 ("utanh", DType::F16) => contiguous::tanh::HALF,
+                ("urelu", DType::F16) => contiguous::relu::HALF,
                 (name, dtype) => {
                     crate::bail!("Metal contiguous unary {name} {dtype:?} not implemented")
                 }
@@ -723,6 +737,7 @@ impl BackendStorage for MetalStorage {
                 ("uabs", DType::F32) => strided::abs::FLOAT,
                 ("uceil", DType::F32) => strided::ceil::FLOAT,
                 ("ufloor", DType::F32) => strided::floor::FLOAT,
+                ("urelu", DType::F32) => strided::relu::FLOAT,
                 ("uround", DType::F32) => strided::round::FLOAT,
                 ("ucos", DType::F16) => strided::cos::HALF,
                 ("usin", DType::F16) => strided::sin::HALF,
@@ -737,6 +752,7 @@ impl BackendStorage for MetalStorage {
                 ("uabs", DType::F16) => strided::abs::HALF,
                 ("uceil", DType::F16) => strided::ceil::HALF,
                 ("ufloor", DType::F16) => strided::floor::HALF,
+                ("urelu", DType::F16) => strided::relu::HALF,
                 ("uround", DType::F16) => strided::round::HALF,
                 (name, dtype) => {
                     crate::bail!("Metal strided unary {name} {dtype:?} not implemented")
                 }
@@ -1129,8 +1145,12 @@ impl BackendStorage for MetalStorage {
         let device = self.device();
         let buffer = device.new_buffer(dst_el, dtype, "index_select")?;
         let name = match (ids.dtype, self.dtype) {
+            (DType::U8, DType::BF16) => "is_u8_bf16",
+
             (DType::U32, DType::F32) => "is_u32_f32",
             (DType::U32, DType::F16) => "is_u32_f16",
+            (DType::U32, DType::BF16) => "is_u32_bf16",
+
             (left, right) => {
                 crate::bail!("Metal contiguous index_select {left:?} {right:?} not implemented")
             }
@@ -1320,6 +1340,7 @@ impl MetalStorage {
                 ("lt", DType::F32) => (contiguous::lt::FLOAT, DType::U8),
                 ("ge", DType::F32) => (contiguous::ge::FLOAT, DType::U8),
                 ("gt", DType::F32) => (contiguous::gt::FLOAT, DType::U8),
+
                 ("add", DType::F16) => (contiguous::add::HALF, self.dtype),
                 ("sub", DType::F16) => (contiguous::sub::HALF, self.dtype),
                 ("mul", DType::F16) => (contiguous::mul::HALF, self.dtype),
@@ -1330,6 +1351,18 @@ impl MetalStorage {
                 ("lt", DType::F16) => (contiguous::lt::HALF, DType::U8),
                 ("ge", DType::F16) => (contiguous::ge::HALF, DType::U8),
                 ("gt", DType::F16) => (contiguous::gt::HALF, DType::U8),
+
+                ("add", DType::BF16) => (contiguous::add::BFLOAT, self.dtype),
+                ("sub", DType::BF16) => (contiguous::sub::BFLOAT, self.dtype),
+                ("mul", DType::BF16) => (contiguous::mul::BFLOAT, self.dtype),
+                ("div", DType::BF16) => (contiguous::div::BFLOAT, self.dtype),
+                ("eq", DType::BF16) => (contiguous::eq::BFLOAT, DType::U8),
+                ("ne", DType::BF16) => (contiguous::ne::BFLOAT, DType::U8),
+                ("le", DType::BF16) => (contiguous::le::BFLOAT, DType::U8),
+                ("lt", DType::BF16) => (contiguous::lt::BFLOAT, DType::U8),
+                ("ge", DType::BF16) => (contiguous::ge::BFLOAT, DType::U8),
+                ("gt", DType::BF16) => (contiguous::gt::BFLOAT, DType::U8),
+
                 ("add", DType::I64) => (contiguous::add::I64, self.dtype),
                 ("sub", DType::I64) => (contiguous::sub::I64, self.dtype),
                 ("mul", DType::I64) => (contiguous::mul::I64, self.dtype),
@@ -1340,6 +1373,7 @@ impl MetalStorage {
                 ("lt", DType::I64) => (contiguous::lt::I64, DType::U8),
                 ("ge", DType::I64) => (contiguous::ge::I64, DType::U8),
                 ("gt", DType::I64) => (contiguous::gt::I64, DType::U8),
+
                 ("add", DType::U32) => (contiguous::add::U32, self.dtype),
                 ("sub", DType::U32) => (contiguous::sub::U32, self.dtype),
                 ("mul", DType::U32) => (contiguous::mul::U32, self.dtype),
@@ -1350,6 +1384,7 @@ impl MetalStorage {
                 ("lt", DType::U32) => (contiguous::lt::U32, DType::U8),
                 ("ge", DType::U32) => (contiguous::ge::U32, DType::U8),
                 ("gt", DType::U32) => (contiguous::gt::U32, DType::U8),
+
                 ("add", DType::U8) => (contiguous::add::U8, self.dtype),
                 ("sub", DType::U8) => (contiguous::sub::U8, self.dtype),
                 ("mul", DType::U8) => (contiguous::mul::U8, self.dtype),
@@ -1360,6 +1395,7 @@ impl MetalStorage {
                 ("lt", DType::U8) => (contiguous::lt::U8, DType::U8),
                 ("ge", DType::U8) => (contiguous::ge::U8, DType::U8),
                 ("gt", DType::U8) => (contiguous::gt::U8, DType::U8),
+
                 (name, dtype) => {
                     crate::bail!("Metal contiguous binary {name} {dtype:?} not implemented")
                 }
@@ -1393,6 +1429,7 @@ impl MetalStorage {
                 ("lt", DType::F32) => (strided::lt::FLOAT, DType::U8),
                 ("ge", DType::F32) => (strided::ge::FLOAT, DType::U8),
                 ("gt", DType::F32) => (strided::gt::FLOAT, DType::U8),
+
                 ("badd", DType::F16) => (strided::add::HALF, self.dtype),
                 ("bsub", DType::F16) => (strided::sub::HALF, self.dtype),
                 ("bmul", DType::F16) => (strided::mul::HALF, self.dtype),
@@ -1405,6 +1442,20 @@ impl MetalStorage {
                 ("lt", DType::F16) => (strided::lt::HALF, DType::U8),
                 ("ge", DType::F16) => (strided::ge::HALF, DType::U8),
                 ("gt", DType::F16) => (strided::gt::HALF, DType::U8),
+
+                ("badd", DType::BF16) => (strided::add::BFLOAT, self.dtype),
+                ("bsub", DType::BF16) => (strided::sub::BFLOAT, self.dtype),
+                ("bmul", DType::BF16) => (strided::mul::BFLOAT, self.dtype),
+                ("bdiv", DType::BF16) => (strided::div::BFLOAT, self.dtype),
+                ("bminimum", DType::BF16) => (strided::min::BFLOAT, self.dtype),
+                ("bmaximum", DType::BF16) => (strided::max::BFLOAT, self.dtype),
+                ("eq", DType::BF16) => (strided::eq::BFLOAT, DType::U8),
+                ("ne", DType::BF16) => (strided::ne::BFLOAT, DType::U8),
+                ("le", DType::BF16) => (strided::le::BFLOAT, DType::U8),
+                ("lt", DType::BF16) => (strided::lt::BFLOAT, DType::U8),
+                ("ge", DType::BF16) => (strided::ge::BFLOAT, DType::U8),
+                ("gt", DType::BF16) => (strided::gt::BFLOAT, DType::U8),
+
                 ("badd", DType::I64) => (strided::add::I64, self.dtype),
                 ("bsub", DType::I64) => (strided::sub::I64, self.dtype),
                 ("bmul", DType::I64) => (strided::mul::I64, self.dtype),
@@ -1417,6 +1468,7 @@ impl MetalStorage {
                 ("lt", DType::I64) => (strided::lt::I64, DType::U8),
                 ("ge", DType::I64) => (strided::ge::I64, DType::U8),
                 ("gt", DType::I64) => (strided::gt::I64, DType::U8),
+
                 ("badd", DType::U32) => (strided::add::U32, self.dtype),
                 ("bsub", DType::U32) => (strided::sub::U32, self.dtype),
                 ("bmul", DType::U32) => (strided::mul::U32, self.dtype),
@@ -1429,6 +1481,7 @@ impl MetalStorage {
                 ("lt", DType::U32) => (strided::lt::U32, DType::U8),
                 ("ge", DType::U32) => (strided::ge::U32, DType::U8),
                 ("gt", DType::U32) => (strided::gt::U32, DType::U8),
+
                 ("badd", DType::U8) => (strided::add::U8, self.dtype),
                 ("bsub", DType::U8) => (strided::sub::U8, self.dtype),
                 ("bmul", DType::U8) => (strided::mul::U8, self.dtype),
@@ -1441,6 +1494,7 @@ impl MetalStorage {
                 ("lt", DType::U8) => (strided::lt::U8, DType::U8),
                 ("ge", DType::U8) => (strided::ge::U8, DType::U8),
                 ("gt", DType::U8) => (strided::gt::U8, DType::U8),
+
                 (name, dtype) => {
                     crate::bail!("Metal strided binary {name} {dtype:?} not implemented")
                 }
diff --git a/candle-core/src/pickle.rs b/candle-core/src/pickle.rs
index 25640d1a..276b30e3 100644
--- a/candle-core/src/pickle.rs
+++ b/candle-core/src/pickle.rs
@@ -703,6 +703,7 @@ impl PthTensors {
     }

     pub fn get(&self, name: &str) -> Result<Option<Tensor>> {
+        use std::io::Read;
         let tensor_info = match self.tensor_infos.get(name) {
             None => return Ok(None),
             Some(tensor_info) => tensor_info,
@@ -712,14 +713,21 @@ impl PthTensors {
         let mut zip = zip::ZipArchive::new(zip_reader)?;
         let mut reader = zip.by_name(&tensor_info.path)?;

-        // Reading the data is a bit tricky as it can be strided, use an offset, etc.
-        // For now only support the basic case.
-        if tensor_info.layout.start_offset() != 0 || !tensor_info.layout.is_contiguous() {
+        // Reading the data is a bit tricky as it can be strided, for now only support the basic
+        // case.
+        if !tensor_info.layout.is_contiguous() {
             crate::bail!(
                 "cannot retrieve non-contiguous tensors {:?}",
                 tensor_info.layout
             )
         }
+        let start_offset = tensor_info.layout.start_offset();
+        if start_offset > 0 {
+            std::io::copy(
+                &mut reader.by_ref().take(start_offset as u64),
+                &mut std::io::sink(),
+            )?;
+        }
         let tensor = Tensor::from_reader(
             tensor_info.layout.shape().clone(),
             tensor_info.dtype,
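Instead of bailing out on a non-zero start offset, the new code consumes the leading bytes of the zip entry. Since a zip entry reader cannot seek, the prefix is discarded by copying a `Read::take` of the reader into `std::io::sink()`. The same pattern in isolation, with a hypothetical `skip_bytes` helper that is not from the commit:

use std::io::{self, Read};

// Discard `n` bytes from a reader that cannot seek (e.g. a zip entry):
// take the prefix and copy it into a sink.
fn skip_bytes<R: Read>(reader: &mut R, n: u64) -> io::Result<()> {
    let copied = io::copy(&mut reader.by_ref().take(n), &mut io::sink())?;
    if copied < n {
        return Err(io::Error::new(
            io::ErrorKind::UnexpectedEof,
            "reader ended before the requested offset",
        ));
    }
    Ok(())
}

fn main() -> io::Result<()> {
    let data = b"headerpayload";
    let mut reader = &data[..];
    skip_bytes(&mut reader, 6)?; // drop "header"
    let mut rest = String::new();
    reader.read_to_string(&mut rest)?;
    assert_eq!(rest, "payload");
    Ok(())
}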
diff --git a/candle-core/src/quantized/k_quants.rs b/candle-core/src/quantized/k_quants.rs
index d16289e6..6210ac1e 100644
--- a/candle-core/src/quantized/k_quants.rs
+++ b/candle-core/src/quantized/k_quants.rs
@@ -1545,13 +1545,13 @@ impl GgmlType for BlockQ5K {
                 let d2 = d * sc as f32;
                 let m2 = min * m as f32;
                 for (ql, qh) in ql.iter().zip(qh) {
-                    let to_add = if qh & u1 != 0 { 16 } else { 1 };
-                    y[ys_index] = d1 * ((ql & 0xF) + to_add) as f32 - m1;
+                    let to_add = if qh & u1 != 0 { 16f32 } else { 0f32 };
+                    y[ys_index] = d1 * ((ql & 0xF) as f32 + to_add) - m1;
                     ys_index += 1;
                 }
                 for (ql, qh) in ql.iter().zip(qh) {
-                    let to_add = if qh & u2 != 0 { 16 } else { 1 };
-                    y[ys_index] = d2 * ((ql >> 4) + to_add) as f32 - m2;
+                    let to_add = if qh & u2 != 0 { 16f32 } else { 0f32 };
+                    y[ys_index] = d2 * ((ql >> 4) as f32 + to_add) - m2;
                     ys_index += 1;
                 }
                 is += 2;
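The fix above corrects the reconstruction of the fifth quant bit in Q5_K dequantization: `qh` stores bit 4 of each 5-bit value, so it contributes 16 when set and nothing when clear, whereas the old code added 1 in the clear case and skewed every decoded value. A scalar sketch of the corrected decode, with a hypothetical helper that is not from the commit:

// `ql` holds the low 4 bits, `high_bit_set` says whether the fifth bit
// is set in `qh`, and `d`/`m` are the per-group scale and min.
fn dequant_q5(ql: u8, high_bit_set: bool, d: f32, m: f32) -> f32 {
    let to_add = if high_bit_set { 16f32 } else { 0f32 };
    d * ((ql & 0xF) as f32 + to_add) - m
}

fn main() {
    // With d = 1 and m = 0 the result is the raw 5-bit quant: 3, or 3 + 16.
    assert_eq!(dequant_q5(0x3, false, 1.0, 0.0), 3.0);
    assert_eq!(dequant_q5(0x3, true, 1.0, 0.0), 19.0);
}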
diff --git a/candle-core/src/quantized/neon.rs b/candle-core/src/quantized/neon.rs
index 3cb56229..c4d5d6f4 100644
--- a/candle-core/src/quantized/neon.rs
+++ b/candle-core/src/quantized/neon.rs
@@ -13,6 +13,14 @@ use core::arch::arm::*;
 use core::arch::aarch64::*;

 #[inline(always)]
+unsafe fn vdotq_s32(a: int8x16_t, b: int8x16_t) -> int32x4_t {
+    // TODO: dotprod
+    let p0 = vmull_s8(vget_low_s8(a), vget_low_s8(b));
+    let p1 = vmull_s8(vget_high_s8(a), vget_high_s8(b));
+    vaddq_s32(vpaddlq_s16(p0), vpaddlq_s16(p1))
+}
+
+#[inline(always)]
 pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) -> Result<f32> {
     let qk = QK8_0;
     let nb = n / qk;
@@ -43,15 +51,8 @@ pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) -> Result<f32> {
             let v1_0l = vld1q_s8(y0.qs.as_ptr());
             let v1_0h = vld1q_s8(y0.qs.as_ptr().add(16));

-            // TODO: Support dotprod when it's available outside of nightly.
-            let pl0l = vmull_s8(vget_low_s8(v0_0ls), vget_low_s8(v1_0l));
-            let pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0l));
-            let ph0l = vmull_s8(vget_low_s8(v0_0hs), vget_low_s8(v1_0h));
-            let ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0h));
-
-            let pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
-            let ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
-
+            let pl0 = vdotq_s32(v0_0ls, v1_0l);
+            let ph0 = vdotq_s32(v0_0hs, v1_0h);
             sumv0 = vmlaq_n_f32(
                 sumv0,
                 vcvtq_f32_s32(vaddq_s32(pl0, ph0)),
@@ -82,14 +83,8 @@ pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) -> Result<f32> {
             let y0_0 = vld1q_s8(y0.qs.as_ptr());
             let y0_1 = vld1q_s8(y0.qs.as_ptr().add(16));

-            // TODO dotprod once this is the intrinsics are.
-            let p0_0 = vmull_s8(vget_low_s8(x0_0), vget_low_s8(y0_0));
-            let p0_1 = vmull_s8(vget_high_s8(x0_0), vget_high_s8(y0_0));
-            let p0_2 = vmull_s8(vget_low_s8(x0_1), vget_low_s8(y0_1));
-            let p0_3 = vmull_s8(vget_high_s8(x0_1), vget_high_s8(y0_1));
-
-            let p0 = vaddq_s32(vpaddlq_s16(p0_0), vpaddlq_s16(p0_1));
-            let p1 = vaddq_s32(vpaddlq_s16(p0_2), vpaddlq_s16(p0_3));
+            let p0 = vdotq_s32(x0_0, y0_0);
+            let p1 = vdotq_s32(x0_1, y0_1);

             sumv0 = vmlaq_n_f32(
                 sumv0,
@@ -118,10 +113,7 @@ pub(crate) fn vec_dot_q8k_q8k(n: usize, xs: &[BlockQ8K], ys: &[BlockQ8K]) -> Result<f32> {
             for i in (0..QK_K).step_by(16) {
                 let xs = vld1q_s8(xs.add(i));
                 let ys = vld1q_s8(ys.add(i));
-                let xy_lo = vmull_s8(vget_low_s8(xs), vget_low_s8(ys));
-                let xy_up = vmull_s8(vget_high_s8(xs), vget_high_s8(ys));
-
-                let xy = vaddq_s32(vpaddlq_s16(xy_lo), vpaddlq_s16(xy_up));
+                let xy = vdotq_s32(xs, ys);
                 sum_i = vaddq_s32(sum_i, xy)
             }
             sumf += vaddvq_s32(sum_i) as f32 * scale
@@ -191,30 +183,16 @@ pub(crate) fn vec_dot_q6k_q8k(n: usize, xs: &[BlockQ6K], ys: &[BlockQ8K]) -> Result<f32> {
                 let q6bytes_2 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.2, m4b), q6h_2));
                 let q6bytes_3 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.3, m4b), q6h_3));

-                // TODO: dotprod
-
-                let p0 = vaddq_s16(
-                    vmull_s8(vget_low_s8(q6bytes_0), vget_low_s8(q8bytes.0)),
-                    vmull_s8(vget_high_s8(q6bytes_0), vget_high_s8(q8bytes.0)),
-                );
-                let p1 = vaddq_s16(
-                    vmull_s8(vget_low_s8(q6bytes_1), vget_low_s8(q8bytes.1)),
-                    vmull_s8(vget_high_s8(q6bytes_1), vget_high_s8(q8bytes.1)),
-                );
+                let p0 = vdotq_s32(q6bytes_0, q8bytes.0);
+                let p1 = vdotq_s32(q6bytes_1, q8bytes.1);
                 let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32);
-                isum += vaddvq_s16(p0) as i32 * scale0 + vaddvq_s16(p1) as i32 * scale1;
+                isum += vaddvq_s32(p0) * scale0 + vaddvq_s32(p1) * scale1;
                 scale = scale.add(2);

-                let p2 = vaddq_s16(
-                    vmull_s8(vget_low_s8(q6bytes_2), vget_low_s8(q8bytes.2)),
-                    vmull_s8(vget_high_s8(q6bytes_2), vget_high_s8(q8bytes.2)),
-                );
-                let p3 = vaddq_s16(
-                    vmull_s8(vget_low_s8(q6bytes_3), vget_low_s8(q8bytes.3)),
-                    vmull_s8(vget_high_s8(q6bytes_3), vget_high_s8(q8bytes.3)),
-                );
+                let p2 = vdotq_s32(q6bytes_2, q8bytes.2);
+                let p3 = vdotq_s32(q6bytes_3, q8bytes.3);
                 let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32);
-                isum += vaddvq_s16(p2) as i32 * scale0 + vaddvq_s16(p3) as i32 * scale1;
+                isum += vaddvq_s32(p2) * scale0 + vaddvq_s32(p3) * scale1;
                 scale = scale.add(2);

                 let q8bytes = vld1q_s8_x4(q8);
@@ -234,29 +212,16 @@ pub(crate) fn vec_dot_q6k_q8k(n: usize, xs: &[BlockQ6K], ys: &[BlockQ8K]) -> Result<f32> {
                 let q6bytes_2 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.2, 4), q6h_2));
                 let q6bytes_3 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.3, 4), q6h_3));

-                // TODO: dotprod case.
-                let p0 = vaddq_s16(
-                    vmull_s8(vget_low_s8(q6bytes_0), vget_low_s8(q8bytes.0)),
-                    vmull_s8(vget_high_s8(q6bytes_0), vget_high_s8(q8bytes.0)),
-                );
-                let p1 = vaddq_s16(
-                    vmull_s8(vget_low_s8(q6bytes_1), vget_low_s8(q8bytes.1)),
-                    vmull_s8(vget_high_s8(q6bytes_1), vget_high_s8(q8bytes.1)),
-                );
+                let p0 = vdotq_s32(q6bytes_0, q8bytes.0);
+                let p1 = vdotq_s32(q6bytes_1, q8bytes.1);
                 let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32);
-                isum += vaddvq_s16(p0) as i32 * scale0 + vaddvq_s16(p1) as i32 * scale1;
+                isum += vaddvq_s32(p0) * scale0 + vaddvq_s32(p1) * scale1;
                 scale = scale.add(2);

-                let p2 = vaddq_s16(
-                    vmull_s8(vget_low_s8(q6bytes_2), vget_low_s8(q8bytes.2)),
-                    vmull_s8(vget_high_s8(q6bytes_2), vget_high_s8(q8bytes.2)),
-                );
-                let p3 = vaddq_s16(
-                    vmull_s8(vget_low_s8(q6bytes_3), vget_low_s8(q8bytes.3)),
-                    vmull_s8(vget_high_s8(q6bytes_3), vget_high_s8(q8bytes.3)),
-                );
+                let p2 = vdotq_s32(q6bytes_2, q8bytes.2);
+                let p3 = vdotq_s32(q6bytes_3, q8bytes.3);
                 let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32);
-                isum += vaddvq_s16(p2) as i32 * scale0 + vaddvq_s16(p3) as i32 * scale1;
+                isum += vaddvq_s32(p2) * scale0 + vaddvq_s32(p3) * scale1;
                 scale = scale.add(2);
             }
             sum += d_all * y.d * ((isum - 32 * isum_mins) as f32);
@@ -333,28 +298,14 @@ pub(crate) fn vec_dot_q5k_q8k(n: usize, xs: &[BlockQ5K], ys: &[BlockQ8K]) -> Result<f32> {
                 let q5bytes_2 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.0, 4), q5h_2));
                 let q5bytes_3 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.1, 4), q5h_3));

-                // TODO: dotprod
-
-                let p0 = vaddq_s16(
-                    vmull_s8(vget_low_s8(q5bytes_0), vget_low_s8(q8bytes.0)),
-                    vmull_s8(vget_high_s8(q5bytes_0), vget_high_s8(q8bytes.0)),
-                );
-                let p1 = vaddq_s16(
-                    vmull_s8(vget_low_s8(q5bytes_1), vget_low_s8(q8bytes.1)),
-                    vmull_s8(vget_high_s8(q5bytes_1), vget_high_s8(q8bytes.1)),
-                );
-                sumi += vaddvq_s16(vaddq_s16(p0, p1)) as i32 * *scales as i32;
+                let p0 = vdotq_s32(q5bytes_0, q8bytes.0);
+                let p1 = vdotq_s32(q5bytes_1, q8bytes.1);
+                sumi += vaddvq_s32(vaddq_s32(p0, p1)) * *scales as i32;
                 scales = scales.add(1);

-                let p2 = vaddq_s16(
-                    vmull_s8(vget_low_s8(q5bytes_2), vget_low_s8(q8bytes.2)),
-                    vmull_s8(vget_high_s8(q5bytes_2), vget_high_s8(q8bytes.2)),
-                );
-                let p3 = vaddq_s16(
-                    vmull_s8(vget_low_s8(q5bytes_3), vget_low_s8(q8bytes.3)),
-                    vmull_s8(vget_high_s8(q5bytes_3), vget_high_s8(q8bytes.3)),
-                );
-                sumi += vaddvq_s16(vaddq_s16(p2, p3)) as i32 * *scales as i32;
+                let p2 = vdotq_s32(q5bytes_2, q8bytes.2);
+                let p3 = vdotq_s32(q5bytes_3, q8bytes.3);
+                sumi += vaddvq_s32(vaddq_s32(p2, p3)) * *scales as i32;
                 scales = scales.add(1);
             }
             sumf += d * sumi as f32 - dmin * sumi_mins as f32;
@@ -417,22 +368,15 @@ pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Result<f32> {
         for j in 0..QK_K / 64 {
             let q4bits = vld1q_u8_x2(q4);
             q4 = q4.add(32);
-            // TODO: dotprod
             let q8bytes = vld1q_s8_x2(q8);
             q8 = q8.add(32);
             let q4bytes = int8x16x2_t(
                 vreinterpretq_s8_u8(vandq_u8(q4bits.0, m4b)),
                 vreinterpretq_s8_u8(vandq_u8(q4bits.1, m4b)),
             );
-            let p0 = vaddq_s16(
-                vmull_s8(vget_low_s8(q4bytes.0), vget_low_s8(q8bytes.0)),
-                vmull_s8(vget_high_s8(q4bytes.0), vget_high_s8(q8bytes.0)),
-            );
-            let p1 = vaddq_s16(
-                vmull_s8(vget_low_s8(q4bytes.1), vget_low_s8(q8bytes.1)),
-                vmull_s8(vget_high_s8(q4bytes.1), vget_high_s8(q8bytes.1)),
-            );
-            sumi1 += vaddvq_s16(vaddq_s16(p0, p1)) as i32 * scales[2 * j] as i32;
+            let p0 = vdotq_s32(q4bytes.0, q8bytes.0);
+            let p1 = vdotq_s32(q4bytes.1, q8bytes.1);
+            sumi1 += vaddvq_s32(vaddq_s32(p0, p1)) * scales[2 * j] as i32;

             let q8bytes = vld1q_s8_x2(q8);
             q8 = q8.add(32);
@@ -440,15 +384,9 @@ pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Result<f32> {
                 vreinterpretq_s8_u8(vshrq_n_u8(q4bits.0, 4)),
                 vreinterpretq_s8_u8(vshrq_n_u8(q4bits.1, 4)),
             );
-            let p2 = vaddq_s16(
-                vmull_s8(vget_low_s8(q4bytes.0), vget_low_s8(q8bytes.0)),
-                vmull_s8(vget_high_s8(q4bytes.0), vget_high_s8(q8bytes.0)),
-            );
-            let p3 = vaddq_s16(
-                vmull_s8(vget_low_s8(q4bytes.1), vget_low_s8(q8bytes.1)),
-                vmull_s8(vget_high_s8(q4bytes.1), vget_high_s8(q8bytes.1)),
-            );
-            sumi2 += vaddvq_s16(vaddq_s16(p2, p3)) as i32 * scales[2 * j + 1] as i32;
+            let p2 = vdotq_s32(q4bytes.0, q8bytes.0);
+            let p3 = vdotq_s32(q4bytes.1, q8bytes.1);
+            sumi2 += vaddvq_s32(vaddq_s32(p2, p3)) * scales[2 * j + 1] as i32;
         }
         sumf += d * (sumi1 + sumi2) as f32;
     }
@@ -526,27 +464,14 @@ pub(crate) fn vec_dot_q3k_q8k(n: usize, xs: &[BlockQ3K], ys: &[BlockQ8K]) -> Result<f32> {
                     vreinterpretq_s8_u8(q3h_3),
                 );

-                // TODO: dotprod
-                let p0 = vaddq_s16(
-                    vmull_s8(vget_low_s8(q3bytes_0), vget_low_s8(q8bytes_1.0)),
-                    vmull_s8(vget_high_s8(q3bytes_0), vget_high_s8(q8bytes_1.0)),
-                );
-                let p1 = vaddq_s16(
-                    vmull_s8(vget_low_s8(q3bytes_1), vget_low_s8(q8bytes_1.1)),
-                    vmull_s8(vget_high_s8(q3bytes_1), vget_high_s8(q8bytes_1.1)),
-                );
-                let p2 = vaddq_s16(
-                    vmull_s8(vget_low_s8(q3bytes_2), vget_low_s8(q8bytes_1.2)),
-                    vmull_s8(vget_high_s8(q3bytes_2), vget_high_s8(q8bytes_1.2)),
-                );
-                let p3 = vaddq_s16(
-                    vmull_s8(vget_low_s8(q3bytes_3), vget_low_s8(q8bytes_1.3)),
-                    vmull_s8(vget_high_s8(q3bytes_3), vget_high_s8(q8bytes_1.3)),
-                );
-                isum += vaddvq_s16(p0) as i32 * *scale as i32
-                    + vaddvq_s16(p1) as i32 * *scale.add(1) as i32
-                    + vaddvq_s16(p2) as i32 * *scale.add(2) as i32
-                    + vaddvq_s16(p3) as i32 * *scale.add(3) as i32;
+                let p0 = vdotq_s32(q3bytes_0, q8bytes_1.0);
+                let p1 = vdotq_s32(q3bytes_1, q8bytes_1.1);
+                let p2 = vdotq_s32(q3bytes_2, q8bytes_1.2);
+                let p3 = vdotq_s32(q3bytes_3, q8bytes_1.3);
+                isum += vaddvq_s32(p0) * *scale as i32
+                    + vaddvq_s32(p1) * *scale.add(1) as i32
+                    + vaddvq_s32(p2) * *scale.add(2) as i32
+                    + vaddvq_s32(p3) * *scale.add(3) as i32;
                 scale = scale.add(4);

                 let q3h_0 = vbicq_u8(m2, qhbits.0);
@@ -571,27 +496,14 @@ pub(crate) fn vec_dot_q3k_q8k(n: usize, xs: &[BlockQ3K], ys: &[BlockQ8K]) -> Result<f32> {
                     vreinterpretq_s8_u8(q3h_3),
                 );

-                // TODO: dotprod
-                let p0 = vaddq_s16(
-                    vmull_s8(vget_low_s8(q3bytes_0), vget_low_s8(q8bytes_2.0)),
-                    vmull_s8(vget_high_s8(q3bytes_0), vget_high_s8(q8bytes_2.0)),
-                );
-                let p1 = vaddq_s16(
-                    vmull_s8(vget_low_s8(q3bytes_1), vget_low_s8(q8bytes_2.1)),
-                    vmull_s8(vget_high_s8(q3bytes_1), vget_high_s8(q8bytes_2.1)),
-                );
-                let p2 = vaddq_s16(
-                    vmull_s8(vget_low_s8(q3bytes_2), vget_low_s8(q8bytes_2.2)),
-                    vmull_s8(vget_high_s8(q3bytes_2), vget_high_s8(q8bytes_2.2)),
-                );
-                let p3 = vaddq_s16(
-                    vmull_s8(vget_low_s8(q3bytes_3), vget_low_s8(q8bytes_2.3)),
-                    vmull_s8(vget_high_s8(q3bytes_3), vget_high_s8(q8bytes_2.3)),
-                );
-                isum += vaddvq_s16(p0) as i32 * *scale as i32
-                    + vaddvq_s16(p1) as i32 * *scale.add(1) as i32
-                    + vaddvq_s16(p2) as i32 * *scale.add(2) as i32
-                    + vaddvq_s16(p3) as i32 * *scale.add(3) as i32;
+                let p0 = vdotq_s32(q3bytes_0, q8bytes_2.0);
+                let p1 = vdotq_s32(q3bytes_1, q8bytes_2.1);
+                let p2 = vdotq_s32(q3bytes_2, q8bytes_2.2);
+                let p3 = vdotq_s32(q3bytes_3, q8bytes_2.3);
+                isum += vaddvq_s32(p0) * *scale as i32
+                    + vaddvq_s32(p1) * *scale.add(1) as i32
+                    + vaddvq_s32(p2) * *scale.add(2) as i32
+                    + vaddvq_s32(p3) * *scale.add(3) as i32;
                 scale = scale.add(4);

                 if j == 0 {
@@ -649,7 +561,6 @@ pub(crate) fn vec_dot_q2k_q8k(n: usize, xs: &[BlockQ2K], ys: &[BlockQ8K]) -> Result<f32> {
         let mut is = 0usize;

         // TODO: dotprod
-
         for _j in 0..QK_K / 128 {
             let q2bits = vld1q_u8_x2(q2);
             q2 = q2.add(32);
@@ -696,14 +607,7 @@ unsafe fn multiply_accum_with_scale(
    q2bytes: int8x16x2_t,
    q8bytes: int8x16x2_t,
) -> i32 {
-    let p1 = vaddq_s16(
-        vmull_s8(vget_low_s8(q2bytes.0), vget_low_s8(q8bytes.0)),
-        vmull_s8(vget_high_s8(q2bytes.0), vget_high_s8(q8bytes.0)),
-    );
-    let p2 = vaddq_s16(
-        vmull_s8(vget_low_s8(q2bytes.1), vget_low_s8(q8bytes.1)),
-        vmull_s8(vget_high_s8(q2bytes.1), vget_high_s8(q8bytes.1)),
-    );
-    vaddvq_s16(p1) as i32 * aux[is + index] as i32
-        + vaddvq_s16(p2) as i32 * aux[is + 1 + index] as i32
+    let p1 = vdotq_s32(q2bytes.0, q8bytes.0);
+    let p2 = vdotq_s32(q2bytes.1, q8bytes.1);
+    vaddvq_s32(p1) * aux[is + index] as i32 + vaddvq_s32(p2) * aux[is + 1 + index] as i32
 }
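The refactor replaces eight hand-expanded copies of the widen-multiply/pairwise-add pattern with the single `vdotq_s32` helper. Its i32 lane layout differs from a hardware sdot instruction, but every call site only ever reduces the four lanes to one sum, so the results agree. A standalone cross-check of that equivalence, not from the commit:

// Verify that widening i8 products to i16 and pairwise-adding twice
// down to four i32 lanes matches a plain scalar dot product once the
// lanes are summed.
#[cfg(target_arch = "aarch64")]
fn main() {
    use core::arch::aarch64::*;
    let a: [i8; 16] = [1, -2, 3, -4, 5, -6, 7, -8, 9, -10, 11, -12, 13, -14, 15, -16];
    let b: [i8; 16] = [3; 16];
    let expected: i32 = a.iter().zip(b.iter()).map(|(&x, &y)| x as i32 * y as i32).sum();
    unsafe {
        // Same body as the new helper above.
        let (va, vb) = (vld1q_s8(a.as_ptr()), vld1q_s8(b.as_ptr()));
        let p0 = vmull_s8(vget_low_s8(va), vget_low_s8(vb));
        let p1 = vmull_s8(vget_high_s8(va), vget_high_s8(vb));
        let dot = vaddq_s32(vpaddlq_s16(p0), vpaddlq_s16(p1));
        assert_eq!(vaddvq_s32(dot), expected);
    }
}

#[cfg(not(target_arch = "aarch64"))]
fn main() {}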