summaryrefslogtreecommitdiff
path: root/candle-kernels/src
diff options
context:
space:
mode:
authorlaurent <laurent.mazare@gmail.com>2024-02-29 10:54:01 +0100
committerlaurent <laurent.mazare@gmail.com>2024-02-29 10:54:01 +0100
commit2c95b7394a30c11e6f3bb0c452d53e5ffef19737 (patch)
tree771b75eb970e8777f959199bac241f37efc723b7 /candle-kernels/src
parent4fd00b890036ef67391a9cc03f896247d0a75711 (diff)
downloadcandle-2c95b7394a30c11e6f3bb0c452d53e5ffef19737.tar.gz
candle-2c95b7394a30c11e6f3bb0c452d53e5ffef19737.tar.bz2
candle-2c95b7394a30c11e6f3bb0c452d53e5ffef19737.zip
Handle Q5_0 and Q5_1 quants in cuda.
Diffstat (limited to 'candle-kernels/src')
-rw-r--r--candle-kernels/src/quantized.cu16
1 files changed, 9 insertions, 7 deletions
diff --git a/candle-kernels/src/quantized.cu b/candle-kernels/src/quantized.cu
index 4d32f6fa..762395d8 100644
--- a/candle-kernels/src/quantized.cu
+++ b/candle-kernels/src/quantized.cu
@@ -575,7 +575,7 @@ static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const in
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
-static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int k) {
+static __device__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int k) {
const int i = 2*(blockDim.x*blockIdx.x + threadIdx.x);
if (i >= k) {
@@ -595,12 +595,6 @@ static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __
y[iybs + iqs + y_offset] = v.y;
}
-template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
-static void dequantize_block_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int k, cudaStream_t stream) {
- const int num_blocks = (k + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE);
- dequantize_block<qk, qr, dequantize_kernel><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
-}
-
extern "C" __global__ void dequantize_block_q4_0(const void * __restrict__ vx, float * __restrict__ yy, int nb32) {
const int i = blockIdx.x;
@@ -910,6 +904,14 @@ extern "C" __global__ void dequantize_block_q8_K(const void * __restrict__ vx, f
#endif
}
+extern "C" __global__ void dequantize_block_q5_0(const void * __restrict__ vx, float * __restrict__ yy, int nb32) {
+ return dequantize_block<QK5_0, QR5_0, dequantize_q5_0>(vx, yy, nb32);
+}
+
+extern "C" __global__ void dequantize_block_q5_1(const void * __restrict__ vx, float * __restrict__ yy, int nb32) {
+ return dequantize_block<QK5_1, QR5_1, dequantize_q5_1>(vx, yy, nb32);
+}
+
template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
static __device__ void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows) {