-rw-r--r--   candle-core/src/quantized/cuda.rs       64
-rw-r--r--   candle-core/tests/quantized_tests.rs    36
-rw-r--r--   candle-kernels/src/quantized.cu        264
3 files changed, 335 insertions, 29 deletions
diff --git a/candle-core/src/quantized/cuda.rs b/candle-core/src/quantized/cuda.rs
index 54b1da41..d6a61682 100644
--- a/candle-core/src/quantized/cuda.rs
+++ b/candle-core/src/quantized/cuda.rs
@@ -166,6 +166,7 @@ fn mul_mat_vec_via_q8_1(
dtype: GgmlDType,
ncols: usize,
nrows: usize,
+ b_size: usize,
dev: &CudaDevice,
) -> Result<CudaStorage> {
use cudarc::driver::LaunchAsync;
@@ -174,14 +175,18 @@ fn mul_mat_vec_via_q8_1(
if data_elems < ncols * nrows {
crate::bail!("unexpected data size {}, ncols {ncols} {nrows}", data_elems)
}
- if y.len() != ncols {
+ if y.len() != ncols * b_size {
crate::bail!("unexpected y size {}, ncols {ncols} {nrows}", y.len())
}
+ if b_size == 0 || b_size > 4 {
+ crate::bail!("only bsize between 1 and 4 are supported, got {b_size}")
+ }
// Start by quantizing y
let ncols_padded = pad(ncols, MATRIX_ROW_PADDING);
- let y_size_in_bytes = ncols_padded * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
+ let y_size_in_bytes =
+ b_size * ncols_padded * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes).w()? };
- quantize_q8_1(y, &mut y_q8_1, ncols, 1, dev)?;
+ quantize_q8_1(y, &mut y_q8_1, ncols, b_size, dev)?;
let kernel_name = match dtype {
GgmlDType::Q4_0 => "mul_mat_vec_q4_0_q8_1_cuda",
@@ -196,10 +201,16 @@ fn mul_mat_vec_via_q8_1(
GgmlDType::Q6K => "mul_mat_vec_q6_K_q8_1_cuda",
_ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"),
};
- let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
- let dst = unsafe { dev.alloc::<f32>(nrows).w()? };
+ let kernel_name = format!("{kernel_name}{b_size}");
+ let func = dev.get_or_load_func(&kernel_name, candle_kernels::QUANTIZED)?;
+ let dst = unsafe { dev.alloc::<f32>(nrows * b_size).w()? };
+ let nblocks = if b_size == 1 {
+ nrows as u32
+ } else {
+ (nrows as u32 + 1) / 2
+ };
let cfg = cudarc::driver::LaunchConfig {
- grid_dim: (nrows as u32, 1, 1),
+ grid_dim: (nblocks, 1, 1),
block_dim: (WARP_SIZE as u32, 4, 1),
shared_mem_bytes: 0,
};
@@ -210,7 +221,7 @@ fn mul_mat_vec_via_q8_1(
&dst,
/* ncols_x */ ncols as i32,
/* nrows_x */ nrows as i32,
- /* nrows_y */ ncols as i32,
+ /* nrows_y */ ncols_padded as i32,
/* nrows_dst */ nrows as i32,
);
unsafe { func.launch(cfg, params) }.w()?;
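
Aside (not part of the patch): the batch-size dispatch above can be read as a small standalone rule. A minimal sketch, assuming kernel names are suffixed with the batch size and that batch sizes 2-4 use blocks that each compute two output rows; the function name is hypothetical.

fn vec_kernel_dispatch(base: &str, b_size: usize, nrows: usize) -> Result<(String, u32), String> {
    if b_size == 0 || b_size > 4 {
        return Err(format!("only batch sizes between 1 and 4 are supported, got {b_size}"));
    }
    // e.g. "mul_mat_vec_q4_0_q8_1_cuda" + "3" -> "mul_mat_vec_q4_0_q8_1_cuda3"
    let kernel_name = format!("{base}{b_size}");
    // b_size == 1: one CUDA block per output row; b_size 2..=4: two rows per block.
    let nblocks = if b_size == 1 {
        nrows as u32
    } else {
        (nrows as u32 + 1) / 2
    };
    Ok((kernel_name, nblocks))
}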
@@ -384,7 +395,17 @@ impl QCudaStorage {
storage: &CudaStorage,
layout: &crate::Layout,
) -> Result<(CudaStorage, crate::Shape)> {
- if matches!(layout.shape().dims(), [1, 1, _] | [1, _]) {
+ let max_bm = if FORCE_DMMV.load(std::sync::atomic::Ordering::Relaxed) {
+ 1
+ } else {
+ 4
+ };
+ let use_vec_kernel = match layout.shape().dims() {
+ [b, m, _k] => b * m <= max_bm,
+ [b, _k] => *b <= max_bm,
+ _ => false,
+ };
+ if use_vec_kernel {
self.dequantize_matmul_vec(self_shape, storage, layout)
} else {
self.dequantize_matmul(self_shape, storage, layout)
@@ -405,25 +426,31 @@ impl QCudaStorage {
Some((o1, o2)) => rhs.slice(o1..o2),
None => Err(crate::Error::RequiresContiguous { op: "dmmv" }.bt())?,
};
- let (with_batch, k) = match rhs_l.shape().dims() {
- [1, 1, k] => (true, k),
- [1, k] => (false, k),
+ let (b_size, k) = match rhs_l.shape().dims() {
+ [b, m, k] => (b * m, *k),
+ [b, k] => (*b, *k),
_ => crate::bail!("unexpected rhs shape in dmmv {:?}", rhs_l.shape()),
};
- if ncols != *k {
+ if ncols != k {
crate::bail!("mismatch on matmul dim {self_shape:?} {:?}", rhs_l.shape())
}
let out = if FORCE_DMMV.load(std::sync::atomic::Ordering::Relaxed) {
dequantize_mul_mat_vec(&self.data, &rhs, self.dtype, ncols, nrows, self.device())?
} else {
- mul_mat_vec_via_q8_1(&self.data, &rhs, self.dtype, ncols, nrows, self.device())?
- };
- let out_shape = if with_batch {
- vec![1, 1, nrows]
- } else {
- vec![1, nrows]
+ mul_mat_vec_via_q8_1(
+ &self.data,
+ &rhs,
+ self.dtype,
+ ncols,
+ nrows,
+ b_size,
+ self.device(),
+ )?
};
+ let mut out_shape = rhs_l.shape().dims().to_vec();
+ out_shape.pop();
+ out_shape.push(nrows);
Ok((out, out_shape.into()))
}
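
Aside (not part of the patch): a minimal standalone sketch of the routing and output-shape rules above; max_bm is 1 when FORCE_DMMV is set and 4 otherwise.

fn qmm_vec_routing(rhs_dims: &[usize], nrows: usize, max_bm: usize) -> (bool, Vec<usize>) {
    // The vec kernel only handles a flattened activation batch of at most max_bm rows.
    let use_vec_kernel = match rhs_dims {
        [b, m, _k] => b * m <= max_bm,
        [b, _k] => *b <= max_bm,
        _ => false,
    };
    // The output keeps the leading dims of the rhs and replaces the last dim by nrows.
    let mut out_shape = rhs_dims.to_vec();
    out_shape.pop();
    out_shape.push(nrows);
    (use_vec_kernel, out_shape)
}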
@@ -522,6 +549,7 @@ mod test {
/* dtype */ GgmlDType::Q4_0,
/* ncols */ ncols,
/* nrows */ 1,
+ /* b_size */ 1,
&dev,
)?;
let vs = cuda_storage.as_cuda_slice::<f32>()?;
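
Aside (not part of the patch): a worked example of the buffer and grid sizing above for a batched call (ncols = 256, nrows = 6, b_size = 3). It assumes MATRIX_ROW_PADDING = 512 and Q8_1 blocks of 32 values stored in 36 bytes, matching the upstream constants; treat both as assumptions here.

fn main() {
    let (ncols, nrows, b_size) = (256usize, 6usize, 3usize);
    let pad = |x: usize, p: usize| (x + p - 1) / p * p;
    let ncols_padded = pad(ncols, 512); // 512
    let y_size_in_bytes = b_size * ncols_padded * 36 / 32; // 1728 bytes of Q8_1 staging
    let nblocks = if b_size == 1 { nrows } else { (nrows + 1) / 2 }; // 3 blocks, 2 rows each
    let dst_elems = nrows * b_size; // 18 output f32 values
    println!("{ncols_padded} {y_size_in_bytes} {nblocks} {dst_elems}");
}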
diff --git a/candle-core/tests/quantized_tests.rs b/candle-core/tests/quantized_tests.rs
index 223accc4..157f2f8d 100644
--- a/candle-core/tests/quantized_tests.rs
+++ b/candle-core/tests/quantized_tests.rs
@@ -170,12 +170,46 @@ fn quantized_matmul_neg(device: &Device) -> Result<()> {
let res2 = matmul.forward(&lhs2)?;
let res2 = res2.i(1)?;
let diff = (res - res2)?.abs()?.sum_all()?.to_vec0::<f32>()?;
- assert_eq!(diff, 0.);
+ if device.is_cuda() {
+ assert!(diff < 0.1);
+ } else {
+ assert_eq!(diff, 0.);
+ }
+ Ok(())
+}
+
+fn qmm_batch(dev: &Device) -> Result<()> {
+ let (lhs, rhs, _mm) = get_random_tensors(2, 256, 6, dev)?;
+ let rhs = quantized::QTensor::quantize(&rhs, GgmlDType::Q2K)?;
+ let rhs = quantized::QMatMul::from_qtensor(rhs)?;
+ let mm = rhs.forward(&lhs)?;
+ assert_eq!(mm.shape().dims(), [2, 6]);
+ let lhs2 = Tensor::cat(&[&lhs, &lhs], 0)?;
+ let mm2 = rhs.forward(&lhs2)?;
+ assert_eq!(mm2.shape().dims(), [4, 6]);
+ let diff2 = (mm2.i(2..)? - &mm)?.abs()?.sum_all()?.to_vec0::<f32>()?;
+ assert_eq!(diff2, 0.0);
+ let lhs3 = Tensor::cat(&[&lhs2, &lhs], 0)?;
+ let mm3 = rhs.forward(&lhs3)?;
+ assert_eq!(mm3.shape().dims(), [6, 6]);
+ let diff3 = (mm3.i(2..4)? - &mm)?.abs()?.sum_all()?.to_vec0::<f32>()?;
+ if dev.is_cuda() {
+ assert!(diff3 < 1e-4)
+ } else {
+ assert_eq!(diff3, 0.0)
+ };
+ let diff3 = (mm3.i(4..)? - &mm)?.abs()?.sum_all()?.to_vec0::<f32>()?;
+ if dev.is_cuda() {
+ assert!(diff3 < 1e-4)
+ } else {
+ assert_eq!(diff3, 0.0)
+ };
Ok(())
}
test_device!(quantized_matmul, qmm_cpu, qmm_cuda, qmm_metal);
test_device!(quantized_matmul_neg, qmm_n_cpu, qmm_n_cuda, qmm_n_metal);
+test_device!(qmm_batch, qmm_b_cpu, qmm_b_cuda, qmm_b_metal);
fn quantize_q4_0(device: &Device) -> Result<()> {
let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
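
Aside (not part of the tests above): rank-3 inputs take the same batched vec path when b * m <= 4. A hedged usage sketch assuming the candle-core API visible in this diff (QTensor::quantize, QMatMul::from_qtensor, forward) plus Tensor::randn for the inputs; the shapes are illustrative.

use candle_core::{quantized, quantized::GgmlDType, Device, Result, Tensor};

fn qmm_batch_3d(device: &Device) -> Result<()> {
    let lhs = Tensor::randn(0f32, 1f32, (1, 3, 256), device)?; // b * m = 3 <= 4
    let w = Tensor::randn(0f32, 1f32, (6, 256), device)?;
    let w = quantized::QTensor::quantize(&w, GgmlDType::Q4_0)?;
    let qmm = quantized::QMatMul::from_qtensor(w)?;
    let out = qmm.forward(&lhs)?;
    // The batch dims are kept and the last dim becomes nrows.
    assert_eq!(out.shape().dims(), [1, 3, 6]);
    Ok(())
}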
diff --git a/candle-kernels/src/quantized.cu b/candle-kernels/src/quantized.cu
index fa38f325..7e3e7b4c 100644
--- a/candle-kernels/src/quantized.cu
+++ b/candle-kernels/src/quantized.cu
@@ -2648,7 +2648,8 @@ static __device__ void mul_mat_vec_q(
}
}
-extern "C" __global__ void mul_mat_vec_q4_0_q8_1_cuda(
+// batch size = 1
+extern "C" __global__ void mul_mat_vec_q4_0_q8_1_cuda1(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
@@ -2656,7 +2657,7 @@ extern "C" __global__ void mul_mat_vec_q4_0_q8_1_cuda(
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
}
-extern "C" __global__ void mul_mat_vec_q4_1_q8_1_cuda(
+extern "C" __global__ void mul_mat_vec_q4_1_q8_1_cuda1(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
@@ -2664,7 +2665,7 @@ extern "C" __global__ void mul_mat_vec_q4_1_q8_1_cuda(
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
}
-extern "C" __global__ void mul_mat_vec_q5_0_q8_1_cuda(
+extern "C" __global__ void mul_mat_vec_q5_0_q8_1_cuda1(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
@@ -2672,7 +2673,7 @@ extern "C" __global__ void mul_mat_vec_q5_0_q8_1_cuda(
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
}
-extern "C" __global__ void mul_mat_vec_q5_1_q8_1_cuda(
+extern "C" __global__ void mul_mat_vec_q5_1_q8_1_cuda1(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
@@ -2680,7 +2681,7 @@ extern "C" __global__ void mul_mat_vec_q5_1_q8_1_cuda(
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
}
-extern "C" __global__ void mul_mat_vec_q8_0_q8_1_cuda(
+extern "C" __global__ void mul_mat_vec_q8_0_q8_1_cuda1(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
@@ -2688,7 +2689,7 @@ extern "C" __global__ void mul_mat_vec_q8_0_q8_1_cuda(
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
}
-extern "C" __global__ void mul_mat_vec_q2_K_q8_1_cuda(
+extern "C" __global__ void mul_mat_vec_q2_K_q8_1_cuda1(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
@@ -2696,7 +2697,7 @@ extern "C" __global__ void mul_mat_vec_q2_K_q8_1_cuda(
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
}
-extern "C" __global__ void mul_mat_vec_q3_K_q8_1_cuda(
+extern "C" __global__ void mul_mat_vec_q3_K_q8_1_cuda1(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
@@ -2704,7 +2705,7 @@ extern "C" __global__ void mul_mat_vec_q3_K_q8_1_cuda(
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
}
-extern "C" __global__ void mul_mat_vec_q4_K_q8_1_cuda(
+extern "C" __global__ void mul_mat_vec_q4_K_q8_1_cuda1(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
@@ -2712,7 +2713,7 @@ extern "C" __global__ void mul_mat_vec_q4_K_q8_1_cuda(
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
}
-extern "C" __global__ void mul_mat_vec_q5_K_q8_1_cuda(
+extern "C" __global__ void mul_mat_vec_q5_K_q8_1_cuda1(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
@@ -2720,7 +2721,7 @@ extern "C" __global__ void mul_mat_vec_q5_K_q8_1_cuda(
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
}
-extern "C" __global__ void mul_mat_vec_q6_K_q8_1_cuda(
+extern "C" __global__ void mul_mat_vec_q6_K_q8_1_cuda1(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
@@ -2728,6 +2729,249 @@ extern "C" __global__ void mul_mat_vec_q6_K_q8_1_cuda(
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
}
+// batch size = 2
+extern "C" __global__ void mul_mat_vec_q4_0_q8_1_cuda2(
+ const void * vx, const void * vy, float * dst,
+ const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
+
+ mul_mat_vec_q<2, QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>
+ (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+}
+
+extern "C" __global__ void mul_mat_vec_q4_1_q8_1_cuda2(
+ const void * vx, const void * vy, float * dst,
+ const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
+
+ mul_mat_vec_q<2, QK4_1, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>
+ (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+}
+
+extern "C" __global__ void mul_mat_vec_q5_0_q8_1_cuda2(
+ const void * vx, const void * vy, float * dst,
+ const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
+
+ mul_mat_vec_q<2, QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>
+ (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+}
+
+extern "C" __global__ void mul_mat_vec_q5_1_q8_1_cuda2(
+ const void * vx, const void * vy, float * dst,
+ const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
+
+ mul_mat_vec_q<2, QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>
+ (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+}
+
+extern "C" __global__ void mul_mat_vec_q8_0_q8_1_cuda2(
+ const void * vx, const void * vy, float * dst,
+ const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
+
+ mul_mat_vec_q<2, QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>
+ (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+}
+
+extern "C" __global__ void mul_mat_vec_q2_K_q8_1_cuda2(
+ const void * vx, const void * vy, float * dst,
+ const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
+
+ mul_mat_vec_q<2, QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>
+ (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+}
+
+extern "C" __global__ void mul_mat_vec_q3_K_q8_1_cuda2(
+ const void * vx, const void * vy, float * dst,
+ const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
+
+ mul_mat_vec_q<2, QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>
+ (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+}
+
+extern "C" __global__ void mul_mat_vec_q4_K_q8_1_cuda2(
+ const void * vx, const void * vy, float * dst,
+ const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
+
+ mul_mat_vec_q<2, QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>
+ (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+}
+
+extern "C" __global__ void mul_mat_vec_q5_K_q8_1_cuda2(
+ const void * vx, const void * vy, float * dst,
+ const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
+
+ mul_mat_vec_q<2, QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>
+ (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+}
+
+extern "C" __global__ void mul_mat_vec_q6_K_q8_1_cuda2(
+ const void * vx, const void * vy, float * dst,
+ const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
+
+ mul_mat_vec_q<2, QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>
+ (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+}
+
+// batch size = 3
+extern "C" __global__ void mul_mat_vec_q4_0_q8_1_cuda3(
+ const void * vx, const void * vy, float * dst,
+ const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
+
+ mul_mat_vec_q<3, QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>
+ (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+}
+
+extern "C" __global__ void mul_mat_vec_q4_1_q8_1_cuda3(
+ const void * vx, const void * vy, float * dst,
+ const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
+
+ mul_mat_vec_q<3, QK4_1, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>
+ (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+}
+
+extern "C" __global__ void mul_mat_vec_q5_0_q8_1_cuda3(
+ const void * vx, const void * vy, float * dst,
+ const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
+
+ mul_mat_vec_q<3, QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>
+ (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+}
+
+extern "C" __global__ void mul_mat_vec_q5_1_q8_1_cuda3(
+ const void * vx, const void * vy, float * dst,
+ const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
+
+ mul_mat_vec_q<3, QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>
+ (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+}
+
+extern "C" __global__ void mul_mat_vec_q8_0_q8_1_cuda3(
+ const void * vx, const void * vy, float * dst,
+ const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
+
+ mul_mat_vec_q<3, QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>
+ (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+}
+
+extern "C" __global__ void mul_mat_vec_q2_K_q8_1_cuda3(
+ const void * vx, const void * vy, float * dst,
+ const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
+
+ mul_mat_vec_q<3, QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>
+ (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+}
+
+extern "C" __global__ void mul_mat_vec_q3_K_q8_1_cuda3(
+ const void * vx, const void * vy, float * dst,
+ const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
+
+ mul_mat_vec_q<3, QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>
+ (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+}
+
+extern "C" __global__ void mul_mat_vec_q4_K_q8_1_cuda3(
+ const void * vx, const void * vy, float * dst,
+ const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
+
+ mul_mat_vec_q<3, QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>
+ (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+}
+
+extern "C" __global__ void mul_mat_vec_q5_K_q8_1_cuda3(
+ const void * vx, const void * vy, float * dst,
+ const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
+
+ mul_mat_vec_q<3, QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>
+ (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+}
+
+extern "C" __global__ void mul_mat_vec_q6_K_q8_1_cuda3(
+ const void * vx, const void * vy, float * dst,
+ const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
+
+ mul_mat_vec_q<3, QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>
+ (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+}
+
+// batch size = 4
+extern "C" __global__ void mul_mat_vec_q4_0_q8_1_cuda4(
+ const void * vx, const void * vy, float * dst,
+ const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
+
+ mul_mat_vec_q<4, QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>
+ (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+}
+
+extern "C" __global__ void mul_mat_vec_q4_1_q8_1_cuda4(
+ const void * vx, const void * vy, float * dst,
+ const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
+
+ mul_mat_vec_q<4, QK4_1, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>
+ (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+}
+
+extern "C" __global__ void mul_mat_vec_q5_0_q8_1_cuda4(
+ const void * vx, const void * vy, float * dst,
+ const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
+
+ mul_mat_vec_q<4, QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>
+ (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+}
+
+extern "C" __global__ void mul_mat_vec_q5_1_q8_1_cuda4(
+ const void * vx, const void * vy, float * dst,
+ const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
+
+ mul_mat_vec_q<4, QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>
+ (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+}
+
+extern "C" __global__ void mul_mat_vec_q8_0_q8_1_cuda4(
+ const void * vx, const void * vy, float * dst,
+ const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
+
+ mul_mat_vec_q<4, QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>
+ (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+}
+
+extern "C" __global__ void mul_mat_vec_q2_K_q8_1_cuda4(
+ const void * vx, const void * vy, float * dst,
+ const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
+
+ mul_mat_vec_q<4, QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>
+ (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+}
+
+extern "C" __global__ void mul_mat_vec_q3_K_q8_1_cuda4(
+ const void * vx, const void * vy, float * dst,
+ const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
+
+ mul_mat_vec_q<4, QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>
+ (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+}
+
+extern "C" __global__ void mul_mat_vec_q4_K_q8_1_cuda4(
+ const void * vx, const void * vy, float * dst,
+ const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
+
+ mul_mat_vec_q<4, QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>
+ (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+}
+
+extern "C" __global__ void mul_mat_vec_q5_K_q8_1_cuda4(
+ const void * vx, const void * vy, float * dst,
+ const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
+
+ mul_mat_vec_q<4, QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>
+ (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+}
+
+extern "C" __global__ void mul_mat_vec_q6_K_q8_1_cuda4(
+ const void * vx, const void * vy, float * dst,
+ const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
+
+ mul_mat_vec_q<4, QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>
+ (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+}
+
extern "C" __global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int kx, const int kx_padded) {
const int ix = blockDim.x*blockIdx.x + threadIdx.x;