diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-04-16 21:30:51 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-04-16 21:30:51 +0200 |
commit | 2817643db9c687cacd330ad53385ae278d018c00 (patch) | |
tree | cf56d91fe4f0ae8288889e8711e66e86935c5a2e /candle-kernels | |
parent | 4d14777673c51b66535d6d716991038a86e3448c (diff) | |
download | candle-2817643db9c687cacd330ad53385ae278d018c00.tar.gz candle-2817643db9c687cacd330ad53385ae278d018c00.tar.bz2 candle-2817643db9c687cacd330ad53385ae278d018c00.zip |
Add the mmv kernels for small batch sizes. (#2075)
* Add the mmv kernels for smaller sizes.
* Support more mmv kernels.
* Use the new kernels.
* Fix the call.
* Silly fix.
* Improve the testing.
* Fix for dmmv.
* Add another dedicated test for the batching mmv.
Diffstat (limited to 'candle-kernels')
-rw-r--r-- | candle-kernels/src/quantized.cu | 264 |
1 files changed, 254 insertions, 10 deletions
diff --git a/candle-kernels/src/quantized.cu b/candle-kernels/src/quantized.cu index fa38f325..7e3e7b4c 100644 --- a/candle-kernels/src/quantized.cu +++ b/candle-kernels/src/quantized.cu @@ -2648,7 +2648,8 @@ static __device__ void mul_mat_vec_q( } } -extern "C" __global__ void mul_mat_vec_q4_0_q8_1_cuda( +// batch size = 1 +extern "C" __global__ void mul_mat_vec_q4_0_q8_1_cuda1( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { @@ -2656,7 +2657,7 @@ extern "C" __global__ void mul_mat_vec_q4_0_q8_1_cuda( (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); } -extern "C" __global__ void mul_mat_vec_q4_1_q8_1_cuda( +extern "C" __global__ void mul_mat_vec_q4_1_q8_1_cuda1( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { @@ -2664,7 +2665,7 @@ extern "C" __global__ void mul_mat_vec_q4_1_q8_1_cuda( (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); } -extern "C" __global__ void mul_mat_vec_q5_0_q8_1_cuda( +extern "C" __global__ void mul_mat_vec_q5_0_q8_1_cuda1( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { @@ -2672,7 +2673,7 @@ extern "C" __global__ void mul_mat_vec_q5_0_q8_1_cuda( (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); } -extern "C" __global__ void mul_mat_vec_q5_1_q8_1_cuda( +extern "C" __global__ void mul_mat_vec_q5_1_q8_1_cuda1( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { @@ -2680,7 +2681,7 @@ extern "C" __global__ void mul_mat_vec_q5_1_q8_1_cuda( (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); } -extern "C" __global__ void mul_mat_vec_q8_0_q8_1_cuda( +extern "C" __global__ void mul_mat_vec_q8_0_q8_1_cuda1( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { @@ -2688,7 +2689,7 @@ extern "C" __global__ void mul_mat_vec_q8_0_q8_1_cuda( (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); } -extern "C" __global__ void mul_mat_vec_q2_K_q8_1_cuda( +extern "C" __global__ void mul_mat_vec_q2_K_q8_1_cuda1( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { @@ -2696,7 +2697,7 @@ extern "C" __global__ void mul_mat_vec_q2_K_q8_1_cuda( (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); } -extern "C" __global__ void mul_mat_vec_q3_K_q8_1_cuda( +extern "C" __global__ void mul_mat_vec_q3_K_q8_1_cuda1( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { @@ -2704,7 +2705,7 @@ extern "C" __global__ void mul_mat_vec_q3_K_q8_1_cuda( (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); } -extern "C" __global__ void mul_mat_vec_q4_K_q8_1_cuda( +extern "C" __global__ void mul_mat_vec_q4_K_q8_1_cuda1( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { @@ -2712,7 +2713,7 @@ extern "C" __global__ void mul_mat_vec_q4_K_q8_1_cuda( (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); } -extern "C" __global__ void mul_mat_vec_q5_K_q8_1_cuda( +extern "C" __global__ void mul_mat_vec_q5_K_q8_1_cuda1( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { @@ -2720,7 +2721,7 @@ extern "C" __global__ void mul_mat_vec_q5_K_q8_1_cuda( (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); } -extern "C" __global__ void mul_mat_vec_q6_K_q8_1_cuda( +extern "C" __global__ void mul_mat_vec_q6_K_q8_1_cuda1( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { @@ -2728,6 +2729,249 @@ extern "C" __global__ void mul_mat_vec_q6_K_q8_1_cuda( (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); } +// batch size = 2 +extern "C" __global__ void mul_mat_vec_q4_0_q8_1_cuda2( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<2, QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q4_1_q8_1_cuda2( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<2, QK4_1, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q5_0_q8_1_cuda2( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<2, QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q5_1_q8_1_cuda2( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<2, QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q8_0_q8_1_cuda2( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<2, QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q2_K_q8_1_cuda2( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<2, QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q3_K_q8_1_cuda2( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<2, QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q4_K_q8_1_cuda2( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<2, QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q5_K_q8_1_cuda2( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<2, QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q6_K_q8_1_cuda2( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<2, QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +// batch size = 3 +extern "C" __global__ void mul_mat_vec_q4_0_q8_1_cuda3( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<3, QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q4_1_q8_1_cuda3( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<3, QK4_1, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q5_0_q8_1_cuda3( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<3, QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q5_1_q8_1_cuda3( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<3, QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q8_0_q8_1_cuda3( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<3, QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q2_K_q8_1_cuda3( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<3, QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q3_K_q8_1_cuda3( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<3, QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q4_K_q8_1_cuda3( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<3, QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q5_K_q8_1_cuda3( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<3, QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q6_K_q8_1_cuda3( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<3, QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +// batch size = 4 +extern "C" __global__ void mul_mat_vec_q4_0_q8_1_cuda4( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<4, QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q4_1_q8_1_cuda4( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<4, QK4_1, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q5_0_q8_1_cuda4( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<4, QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q5_1_q8_1_cuda4( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<4, QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q8_0_q8_1_cuda4( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<4, QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q2_K_q8_1_cuda4( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<4, QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q3_K_q8_1_cuda4( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<4, QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q4_K_q8_1_cuda4( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<4, QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q5_K_q8_1_cuda4( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<4, QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + +extern "C" __global__ void mul_mat_vec_q6_K_q8_1_cuda4( + const void * vx, const void * vy, float * dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + + mul_mat_vec_q<4, QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1> + (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + extern "C" __global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int kx, const int kx_padded) { const int ix = blockDim.x*blockIdx.x + threadIdx.x; |