diff options
author | laurent <laurent.mazare@gmail.com> | 2024-02-29 10:54:01 +0100 |
---|---|---|
committer | laurent <laurent.mazare@gmail.com> | 2024-02-29 10:54:01 +0100 |
commit | 2c95b7394a30c11e6f3bb0c452d53e5ffef19737 (patch) | |
tree | 771b75eb970e8777f959199bac241f37efc723b7 /candle-kernels/src | |
parent | 4fd00b890036ef67391a9cc03f896247d0a75711 (diff) | |
download | candle-2c95b7394a30c11e6f3bb0c452d53e5ffef19737.tar.gz candle-2c95b7394a30c11e6f3bb0c452d53e5ffef19737.tar.bz2 candle-2c95b7394a30c11e6f3bb0c452d53e5ffef19737.zip |
Handle Q5_0 and Q5_1 quants in cuda.
Diffstat (limited to 'candle-kernels/src')
-rw-r--r-- | candle-kernels/src/quantized.cu | 16 |
1 files changed, 9 insertions, 7 deletions
diff --git a/candle-kernels/src/quantized.cu b/candle-kernels/src/quantized.cu index 4d32f6fa..762395d8 100644 --- a/candle-kernels/src/quantized.cu +++ b/candle-kernels/src/quantized.cu @@ -575,7 +575,7 @@ static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const in template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t> -static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int k) { +static __device__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int k) { const int i = 2*(blockDim.x*blockIdx.x + threadIdx.x); if (i >= k) { @@ -595,12 +595,6 @@ static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __ y[iybs + iqs + y_offset] = v.y; } -template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t> -static void dequantize_block_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int k, cudaStream_t stream) { - const int num_blocks = (k + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE); - dequantize_block<qk, qr, dequantize_kernel><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k); -} - extern "C" __global__ void dequantize_block_q4_0(const void * __restrict__ vx, float * __restrict__ yy, int nb32) { const int i = blockIdx.x; @@ -910,6 +904,14 @@ extern "C" __global__ void dequantize_block_q8_K(const void * __restrict__ vx, f #endif } +extern "C" __global__ void dequantize_block_q5_0(const void * __restrict__ vx, float * __restrict__ yy, int nb32) { + return dequantize_block<QK5_0, QR5_0, dequantize_q5_0>(vx, yy, nb32); +} + +extern "C" __global__ void dequantize_block_q5_1(const void * __restrict__ vx, float * __restrict__ yy, int nb32) { + return dequantize_block<QK5_1, QR5_1, dequantize_q5_1>(vx, yy, nb32); +} + template <int qk, int qr, dequantize_kernel_t dequantize_kernel> static __device__ void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows) { |