-rw-r--r--   candle-core/src/quantized/cuda.rs    |  64
-rw-r--r--   candle-core/tests/quantized_tests.rs |  36
-rw-r--r--   candle-kernels/src/quantized.cu      | 264
3 files changed, 335 insertions, 29 deletions
diff --git a/candle-core/src/quantized/cuda.rs b/candle-core/src/quantized/cuda.rs
index 54b1da41..d6a61682 100644
--- a/candle-core/src/quantized/cuda.rs
+++ b/candle-core/src/quantized/cuda.rs
@@ -166,6 +166,7 @@ fn mul_mat_vec_via_q8_1(
     dtype: GgmlDType,
     ncols: usize,
     nrows: usize,
+    b_size: usize,
     dev: &CudaDevice,
 ) -> Result<CudaStorage> {
     use cudarc::driver::LaunchAsync;
@@ -174,14 +175,18 @@ fn mul_mat_vec_via_q8_1(
     let data_elems = data.len() / dtype.type_size() * dtype.block_size();
     if data_elems < ncols * nrows {
         crate::bail!("unexpected data size {}, ncols {ncols} {nrows}", data_elems)
     }
-    if y.len() != ncols {
+    if y.len() != ncols * b_size {
         crate::bail!("unexpected y size {}, ncols {ncols} {nrows}", y.len())
     }
+    if b_size == 0 || b_size > 4 {
+        crate::bail!("only bsize between 1 and 4 are supported, got {b_size}")
+    }
     // Start by quantizing y
     let ncols_padded = pad(ncols, MATRIX_ROW_PADDING);
-    let y_size_in_bytes = ncols_padded * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
+    let y_size_in_bytes =
+        b_size * ncols_padded * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
     let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes).w()? };
-    quantize_q8_1(y, &mut y_q8_1, ncols, 1, dev)?;
+    quantize_q8_1(y, &mut y_q8_1, ncols, b_size, dev)?;
     let kernel_name = match dtype {
         GgmlDType::Q4_0 => "mul_mat_vec_q4_0_q8_1_cuda",
@@ -196,10 +201,16 @@ fn mul_mat_vec_via_q8_1(
         GgmlDType::Q6K => "mul_mat_vec_q6_K_q8_1_cuda",
         _ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"),
     };
-    let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
-    let dst = unsafe { dev.alloc::<f32>(nrows).w()? };
+    let kernel_name = format!("{kernel_name}{b_size}");
+    let func = dev.get_or_load_func(&kernel_name, candle_kernels::QUANTIZED)?;
+    let dst = unsafe { dev.alloc::<f32>(nrows * b_size).w()? };
+    let nblocks = if b_size == 1 {
+        nrows as u32
+    } else {
+        (nrows as u32 + 1) / 2
+    };
     let cfg = cudarc::driver::LaunchConfig {
-        grid_dim: (nrows as u32, 1, 1),
+        grid_dim: (nblocks, 1, 1),
         block_dim: (WARP_SIZE as u32, 4, 1),
         shared_mem_bytes: 0,
     };
@@ -210,7 +221,7 @@ fn mul_mat_vec_via_q8_1(
         &dst,
         /* ncols_x */ ncols as i32,
         /* nrows_x */ nrows as i32,
-        /* nrows_y */ ncols as i32,
+        /* nrows_y */ ncols_padded as i32,
         /* nrows_dst */ nrows as i32,
     );
     unsafe { func.launch(cfg, params) }.w()?;
@@ -384,7 +395,17 @@ impl QCudaStorage {
         storage: &CudaStorage,
         layout: &crate::Layout,
     ) -> Result<(CudaStorage, crate::Shape)> {
-        if matches!(layout.shape().dims(), [1, 1, _] | [1, _]) {
+        let max_bm = if FORCE_DMMV.load(std::sync::atomic::Ordering::Relaxed) {
+            1
+        } else {
+            4
+        };
+        let use_vec_kernel = match layout.shape().dims() {
+            [b, m, _k] => b * m <= max_bm,
+            [b, _k] => *b <= max_bm,
+            _ => false,
+        };
+        if use_vec_kernel {
             self.dequantize_matmul_vec(self_shape, storage, layout)
         } else {
             self.dequantize_matmul(self_shape, storage, layout)
@@ -405,25 +426,31 @@ impl QCudaStorage {
             Some((o1, o2)) => rhs.slice(o1..o2),
             None => Err(crate::Error::RequiresContiguous { op: "dmmv" }.bt())?,
         };
-        let (with_batch, k) = match rhs_l.shape().dims() {
-            [1, 1, k] => (true, k),
-            [1, k] => (false, k),
+        let (b_size, k) = match rhs_l.shape().dims() {
+            [b, m, k] => (b * m, *k),
+            [b, k] => (*b, *k),
             _ => crate::bail!("unexpected rhs shape in dmmv {:?}", rhs_l.shape()),
         };
-        if ncols != *k {
+        if ncols != k {
             crate::bail!("mismatch on matmul dim {self_shape:?} {:?}", rhs_l.shape())
         }
 
         let out = if FORCE_DMMV.load(std::sync::atomic::Ordering::Relaxed) {
             dequantize_mul_mat_vec(&self.data, &rhs, self.dtype, ncols, nrows, self.device())?
         } else {
-            mul_mat_vec_via_q8_1(&self.data, &rhs, self.dtype, ncols, nrows, self.device())?
-        };
-        let out_shape = if with_batch {
-            vec![1, 1, nrows]
-        } else {
-            vec![1, nrows]
+            mul_mat_vec_via_q8_1(
+                &self.data,
+                &rhs,
+                self.dtype,
+                ncols,
+                nrows,
+                b_size,
+                self.device(),
+            )?
         };
+        let mut out_shape = rhs_l.shape().dims().to_vec();
+        out_shape.pop();
+        out_shape.push(nrows);
         Ok((out, out_shape.into()))
     }
 
@@ -522,6 +549,7 @@ mod test {
             /* dtype */ GgmlDType::Q4_0,
             /* ncols */ ncols,
             /* nrows */ 1,
+            /* b_size */ 1,
             &dev,
         )?;
         let vs = cuda_storage.as_cuda_slice::<f32>()?;
diff --git a/candle-core/tests/quantized_tests.rs b/candle-core/tests/quantized_tests.rs
index 223accc4..157f2f8d 100644
--- a/candle-core/tests/quantized_tests.rs
+++ b/candle-core/tests/quantized_tests.rs
@@ -170,12 +170,46 @@ fn quantized_matmul_neg(device: &Device) -> Result<()> {
     let res2 = matmul.forward(&lhs2)?;
     let res2 = res2.i(1)?;
     let diff = (res - res2)?.abs()?.sum_all()?.to_vec0::<f32>()?;
-    assert_eq!(diff, 0.);
+    if device.is_cuda() {
+        assert!(diff < 0.1);
+    } else {
+        assert_eq!(diff, 0.);
+    }
+    Ok(())
+}
+
+fn qmm_batch(dev: &Device) -> Result<()> {
+    let (lhs, rhs, _mm) = get_random_tensors(2, 256, 6, dev)?;
+    let rhs = quantized::QTensor::quantize(&rhs, GgmlDType::Q2K)?;
+    let rhs = quantized::QMatMul::from_qtensor(rhs)?;
+    let mm = rhs.forward(&lhs)?;
+    assert_eq!(mm.shape().dims(), [2, 6]);
+    let lhs2 = Tensor::cat(&[&lhs, &lhs], 0)?;
+    let mm2 = rhs.forward(&lhs2)?;
+    assert_eq!(mm2.shape().dims(), [4, 6]);
+    let diff2 = (mm2.i(2..)? - &mm)?.abs()?.sum_all()?.to_vec0::<f32>()?;
+    assert_eq!(diff2, 0.0);
+    let lhs3 = Tensor::cat(&[&lhs2, &lhs], 0)?;
+    let mm3 = rhs.forward(&lhs3)?;
+    assert_eq!(mm3.shape().dims(), [6, 6]);
+    let diff3 = (mm3.i(2..4)? - &mm)?.abs()?.sum_all()?.to_vec0::<f32>()?;
+    if dev.is_cuda() {
+        assert!(diff3 < 1e-4)
+    } else {
+        assert_eq!(diff3, 0.0)
+    };
+    let diff3 = (mm3.i(4..)? - &mm)?.abs()?.sum_all()?.to_vec0::<f32>()?;
+    if dev.is_cuda() {
+        assert!(diff3 < 1e-4)
+    } else {
+        assert_eq!(diff3, 0.0)
+    };
     Ok(())
 }
 
 test_device!(quantized_matmul, qmm_cpu, qmm_cuda, qmm_metal);
 test_device!(quantized_matmul_neg, qmm_n_cpu, qmm_n_cuda, qmm_n_metal);
+test_device!(qmm_batch, qmm_b_cpu, qmm_b_cuda, qmm_b_metal);
 
 fn quantize_q4_0(device: &Device) -> Result<()> {
     let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
diff --git a/candle-kernels/src/quantized.cu b/candle-kernels/src/quantized.cu
index fa38f325..7e3e7b4c 100644
--- a/candle-kernels/src/quantized.cu
+++ b/candle-kernels/src/quantized.cu
@@ -2648,7 +2648,8 @@ static __device__ void mul_mat_vec_q(
     }
 }
 
-extern "C" __global__ void mul_mat_vec_q4_0_q8_1_cuda(
+// batch size = 1
+extern "C" __global__ void mul_mat_vec_q4_0_q8_1_cuda1(
     const void * vx, const void * vy, float * dst,
     const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
 
@@ -2656,7 +2657,7 @@ extern "C" __global__ void mul_mat_vec_q4_0_q8_1_cuda(
         (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
 }
 
-extern "C" __global__ void mul_mat_vec_q4_1_q8_1_cuda(
+extern "C" __global__ void mul_mat_vec_q4_1_q8_1_cuda1(
     const void * vx, const void * vy, float * dst,
     const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
 
@@ -2664,7 +2665,7 @@ extern "C" __global__ void mul_mat_vec_q4_1_q8_1_cuda(
         (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
 }
 
-extern "C" __global__ void mul_mat_vec_q5_0_q8_1_cuda(
+extern "C" __global__ void mul_mat_vec_q5_0_q8_1_cuda1(
     const void * vx, const void * vy, float * dst,
     const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
 
@@ -2672,7 +2673,7 @@ extern "C" __global__ void mul_mat_vec_q5_0_q8_1_cuda(
         (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
 }
 
-extern "C" __global__ void mul_mat_vec_q5_1_q8_1_cuda(
+extern "C" __global__ void mul_mat_vec_q5_1_q8_1_cuda1(
     const void * vx, const void * vy, float * dst,
     const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
 
@@ -2680,7 +2681,7 @@ extern "C" __global__ void mul_mat_vec_q5_1_q8_1_cuda(
         (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
 }
 
-extern "C" __global__ void mul_mat_vec_q8_0_q8_1_cuda(
+extern "C" __global__ void mul_mat_vec_q8_0_q8_1_cuda1(
     const void * vx, const void * vy, float * dst,
     const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
 
@@ -2688,7 +2689,7 @@ extern "C" __global__ void mul_mat_vec_q8_0_q8_1_cuda(
         (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
 }
 
-extern "C" __global__ void mul_mat_vec_q2_K_q8_1_cuda(
+extern "C" __global__ void mul_mat_vec_q2_K_q8_1_cuda1(
     const void * vx, const void * vy, float * dst,
     const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
 
@@ -2696,7 +2697,7 @@ extern "C" __global__ void mul_mat_vec_q2_K_q8_1_cuda(
         (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
 }
 
-extern "C" __global__ void mul_mat_vec_q3_K_q8_1_cuda(
+extern "C" __global__ void mul_mat_vec_q3_K_q8_1_cuda1(
     const void * vx, const void * vy, float * dst,
     const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
 
@@ -2704,7 +2705,7 @@ extern "C" __global__ void mul_mat_vec_q3_K_q8_1_cuda(
         (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
 }
 
-extern "C" __global__ void mul_mat_vec_q4_K_q8_1_cuda(
+extern "C" __global__ void mul_mat_vec_q4_K_q8_1_cuda1(
     const void * vx, const void * vy, float * dst,
     const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
 
@@ -2712,7 +2713,7 @@ extern "C" __global__ void mul_mat_vec_q4_K_q8_1_cuda(
         (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
 }
 
-extern "C" __global__ void mul_mat_vec_q5_K_q8_1_cuda(
+extern "C" __global__ void mul_mat_vec_q5_K_q8_1_cuda1(
     const void * vx, const void * vy, float * dst,
     const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
 
@@ -2720,7 +2721,7 @@ extern "C" __global__ void mul_mat_vec_q5_K_q8_1_cuda(
         (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
 }
 
-extern "C" __global__ void mul_mat_vec_q6_K_q8_1_cuda(
+extern "C" __global__ void mul_mat_vec_q6_K_q8_1_cuda1(
     const void * vx, const void * vy, float * dst,
     const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
 
@@ -2728,6 +2729,249 @@ extern "C" __global__ void mul_mat_vec_q6_K_q8_1_cuda(
         (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
 }
 
+// batch size = 2
+extern "C" __global__ void mul_mat_vec_q4_0_q8_1_cuda2(
+    const void * vx, const void * vy, float * dst,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
+
+    mul_mat_vec_q<2, QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>
+        (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+}
+
+extern "C" __global__ void mul_mat_vec_q4_1_q8_1_cuda2(
+    const void * vx, const void * vy, float * dst,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
+
+    mul_mat_vec_q<2, QK4_1, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>
+        (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+}
+
+extern "C" __global__ void mul_mat_vec_q5_0_q8_1_cuda2(
+    const void * vx, const void * vy, float * dst,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
+
+    mul_mat_vec_q<2, QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>
+        (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+}
+
+extern "C" __global__ void mul_mat_vec_q5_1_q8_1_cuda2(
+    const void * vx, const void * vy, float * dst,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
+
+    mul_mat_vec_q<2, QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>
+        (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+}
+
+extern "C" __global__ void mul_mat_vec_q8_0_q8_1_cuda2(
+    const void * vx, const void * vy, float * dst,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
+
+    mul_mat_vec_q<2, QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>
+        (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+}
+
+extern "C" __global__ void mul_mat_vec_q2_K_q8_1_cuda2(
+    const void * vx, const void * vy, float * dst,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
+
+    mul_mat_vec_q<2, QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>
+        (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+}
+
+extern "C" __global__ void mul_mat_vec_q3_K_q8_1_cuda2(
+    const void * vx, const void * vy, float * dst,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
+
+    mul_mat_vec_q<2, QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>
+        (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+}
+
+extern "C" __global__ void mul_mat_vec_q4_K_q8_1_cuda2(
+    const void * vx, const void * vy, float * dst,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
+
+    mul_mat_vec_q<2, QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>
+        (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+}
+
+extern "C" __global__ void mul_mat_vec_q5_K_q8_1_cuda2(
+    const void * vx, const void * vy, float * dst,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
+
+    mul_mat_vec_q<2, QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>
+        (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+}
+
+extern "C" __global__ void mul_mat_vec_q6_K_q8_1_cuda2(
+    const void * vx, const void * vy, float * dst,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
+
+    mul_mat_vec_q<2, QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>
+        (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+}
+
+// batch size = 3
+extern "C" __global__ void mul_mat_vec_q4_0_q8_1_cuda3(
+    const void * vx, const void * vy, float * dst,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
+
+    mul_mat_vec_q<3, QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>
+        (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+}
+
+extern "C" __global__ void mul_mat_vec_q4_1_q8_1_cuda3(
+    const void * vx, const void * vy, float * dst,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
+
+    mul_mat_vec_q<3, QK4_1, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>
+        (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+}
+
+extern "C" __global__ void mul_mat_vec_q5_0_q8_1_cuda3(
+    const void * vx, const void * vy, float * dst,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
+
+    mul_mat_vec_q<3, QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>
+        (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+}
+
+extern "C" __global__ void mul_mat_vec_q5_1_q8_1_cuda3(
+    const void * vx, const void * vy, float * dst,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
+
+    mul_mat_vec_q<3, QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>
+        (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+}
+
+extern "C" __global__ void mul_mat_vec_q8_0_q8_1_cuda3(
+    const void * vx, const void * vy, float * dst,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
+
+    mul_mat_vec_q<3, QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>
+        (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+}
+
+extern "C" __global__ void mul_mat_vec_q2_K_q8_1_cuda3(
+    const void * vx, const void * vy, float * dst,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
+
+    mul_mat_vec_q<3, QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>
+        (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+}
+
+extern "C" __global__ void mul_mat_vec_q3_K_q8_1_cuda3(
+    const void * vx, const void * vy, float * dst,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
+
+    mul_mat_vec_q<3, QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>
+        (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+}
+
+extern "C" __global__ void mul_mat_vec_q4_K_q8_1_cuda3(
+    const void * vx, const void * vy, float * dst,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
+
+    mul_mat_vec_q<3, QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>
+        (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+}
+
+extern "C" __global__ void mul_mat_vec_q5_K_q8_1_cuda3(
+    const void * vx, const void * vy, float * dst,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
+
+    mul_mat_vec_q<3, QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>
+        (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+}
+
+extern "C" __global__ void mul_mat_vec_q6_K_q8_1_cuda3(
+    const void * vx, const void * vy, float * dst,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
+
+    mul_mat_vec_q<3, QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>
+        (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+}
+
+// batch size = 4
+extern "C" __global__ void mul_mat_vec_q4_0_q8_1_cuda4(
+    const void * vx, const void * vy, float * dst,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
+
+    mul_mat_vec_q<4, QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>
+        (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+}
+
+extern "C" __global__ void mul_mat_vec_q4_1_q8_1_cuda4(
+    const void * vx, const void * vy, float * dst,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
+
+    mul_mat_vec_q<4, QK4_1, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>
+        (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+}
+
+extern "C" __global__ void mul_mat_vec_q5_0_q8_1_cuda4(
+    const void * vx, const void * vy, float * dst,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
+
+    mul_mat_vec_q<4, QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>
+        (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+}
+
+extern "C" __global__ void mul_mat_vec_q5_1_q8_1_cuda4(
+    const void * vx, const void * vy, float * dst,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
+
+    mul_mat_vec_q<4, QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>
+        (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+}
+
+extern "C" __global__ void mul_mat_vec_q8_0_q8_1_cuda4(
+    const void * vx, const void * vy, float * dst,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
+
+    mul_mat_vec_q<4, QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>
+        (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+}
+
+extern "C" __global__ void mul_mat_vec_q2_K_q8_1_cuda4(
+    const void * vx, const void * vy, float * dst,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
+
+    mul_mat_vec_q<4, QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>
+        (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+}
+
+extern "C" __global__ void mul_mat_vec_q3_K_q8_1_cuda4(
+    const void * vx, const void * vy, float * dst,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
+
+    mul_mat_vec_q<4, QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>
+        (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+}
+
+extern "C" __global__ void mul_mat_vec_q4_K_q8_1_cuda4(
+    const void * vx, const void * vy, float * dst,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
+
+    mul_mat_vec_q<4, QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>
+        (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+}
+
+extern "C" __global__ void mul_mat_vec_q5_K_q8_1_cuda4(
+    const void * vx, const void * vy, float * dst,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
+
+    mul_mat_vec_q<4, QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>
+        (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+}
+
+extern "C" __global__ void mul_mat_vec_q6_K_q8_1_cuda4(
+    const void * vx, const void * vy, float * dst,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
+
+    mul_mat_vec_q<4, QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>
+        (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+}
+
 extern "C" __global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int kx, const int kx_padded) {
     const int ix = blockDim.x*blockIdx.x + threadIdx.x;
 
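Usage sketch: with this change, a CUDA QMatMul::forward call whose flattened batch dimension is between 1 and 4 stays on the mul_mat_vec_q*_q8_1_cudaN path instead of falling back to dequantize_matmul. The snippet below mirrors the new qmm_batch test; the Q4_0 dtype and the (6, 256) / (4, 256) shapes are illustrative choices, not taken from the commit.

use candle_core::quantized::{self, GgmlDType};
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::new_cuda(0)?;
    // Quantize a (nrows = 6, ncols = 256) weight; ncols must be a
    // multiple of the Q4_0 block size (32) for quantization to succeed.
    let weight = Tensor::randn(0f32, 1.0, (6, 256), &dev)?;
    let qweight = quantized::QTensor::quantize(&weight, GgmlDType::Q4_0)?;
    let qmatmul = quantized::QMatMul::from_qtensor(qweight)?;
    // A (4, 256) lhs gives b_size = 4, so mul_mat_vec_via_q8_1 selects the
    // mul_mat_vec_q4_0_q8_1_cuda4 kernel and launches (nrows + 1) / 2 blocks.
    let lhs = Tensor::randn(0f32, 1.0, (4, 256), &dev)?;
    let out = qmatmul.forward(&lhs)?;
    assert_eq!(out.shape().dims(), [4, 6]);
    Ok(())
}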