summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-core/src/quantized/cuda.rs54
-rw-r--r--candle-core/tests/quantized_tests.rs8
-rw-r--r--candle-kernels/src/quantized.cu16
3 files changed, 47 insertions, 31 deletions
diff --git a/candle-core/src/quantized/cuda.rs b/candle-core/src/quantized/cuda.rs
index 84af483d..5b684573 100644
--- a/candle-core/src/quantized/cuda.rs
+++ b/candle-core/src/quantized/cuda.rs
@@ -16,6 +16,7 @@ pub const MMQ_Y_Q4_0_AMPERE: usize = 32;
pub const NWARPS_Q4_0_AMPERE: usize = 4;
pub const GGML_CUDA_MMV_X: usize = 32;
pub const GGML_CUDA_MMV_Y: usize = 1;
+pub const CUDA_DEQUANTIZE_BLOCK_SIZE: usize = 256;
fn dequantize(
data: &CudaSlice<u8>,
@@ -25,28 +26,46 @@ fn dequantize(
) -> Result<CudaStorage> {
use cudarc::driver::LaunchAsync;
- let (kernel_name, is_k, block_dim) = match dtype {
- GgmlDType::Q4_0 => ("dequantize_block_q4_0", false, 32),
- GgmlDType::Q4_1 => ("dequantize_block_q4_1", false, 32),
- GgmlDType::Q5_0 => ("dequantize_block_q5_0", false, 32),
- GgmlDType::Q5_1 => ("dequantize_block_q5_1", false, 32),
- GgmlDType::Q8_0 => ("dequantize_block_q8_0", false, 32),
- GgmlDType::Q2K => ("dequantize_block_q2_K", true, 64),
- GgmlDType::Q3K => ("dequantize_block_q3_K", true, 64),
- GgmlDType::Q4K => ("dequantize_block_q4_K", true, 32),
- GgmlDType::Q5K => ("dequantize_block_q5_K", true, 64),
- GgmlDType::Q6K => ("dequantize_block_q6_K", true, 64),
- GgmlDType::Q8K => ("dequantize_block_q8_K", true, 32),
+ let nb = (elem_count + 255) / 256;
+ let (kernel_name, is_k, block_dim, num_blocks) = match dtype {
+ GgmlDType::Q4_0 => ("dequantize_block_q4_0", false, 32, nb),
+ GgmlDType::Q4_1 => ("dequantize_block_q4_1", false, 32, nb),
+ GgmlDType::Q5_0 => {
+ let nb = (elem_count + 2 * CUDA_DEQUANTIZE_BLOCK_SIZE - 1)
+ / (2 * CUDA_DEQUANTIZE_BLOCK_SIZE);
+ (
+ "dequantize_block_q5_0",
+ false,
+ CUDA_DEQUANTIZE_BLOCK_SIZE,
+ nb,
+ )
+ }
+ GgmlDType::Q5_1 => {
+ let nb = (elem_count + 2 * CUDA_DEQUANTIZE_BLOCK_SIZE - 1)
+ / (2 * CUDA_DEQUANTIZE_BLOCK_SIZE);
+ (
+ "dequantize_block_q5_1",
+ false,
+ CUDA_DEQUANTIZE_BLOCK_SIZE,
+ nb,
+ )
+ }
+ GgmlDType::Q8_0 => ("dequantize_block_q8_0", false, 32, nb),
+ GgmlDType::Q2K => ("dequantize_block_q2_K", true, 64, nb),
+ GgmlDType::Q3K => ("dequantize_block_q3_K", true, 64, nb),
+ GgmlDType::Q4K => ("dequantize_block_q4_K", true, 32, nb),
+ GgmlDType::Q5K => ("dequantize_block_q5_K", true, 64, nb),
+ GgmlDType::Q6K => ("dequantize_block_q6_K", true, 64, nb),
+ GgmlDType::Q8K => ("dequantize_block_q8_K", true, 32, nb),
_ => crate::bail!("unsupported dtype for dequantize {dtype:?}"),
};
let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
let dst = dev.alloc_zeros::<f32>(elem_count).w()?;
- let nb = (elem_count + 255) / 256;
// See e.g.
// https://github.com/ggerganov/llama.cpp/blob/cbbd1efa06f8c09f9dff58ff9d9af509cc4c152b/ggml-cuda.cu#L7270
let cfg = cudarc::driver::LaunchConfig {
- grid_dim: (nb as u32, 1, 1),
- block_dim: (block_dim, 1, 1),
+ grid_dim: (num_blocks as u32, 1, 1),
+ block_dim: (block_dim as u32, 1, 1),
shared_mem_bytes: 0,
};
@@ -54,7 +73,10 @@ fn dequantize(
let params = (data, &dst);
unsafe { func.launch(cfg, params) }.w()?;
} else {
- let nb32 = elem_count / 32;
+ let nb32 = match dtype {
+ GgmlDType::Q5_0 | GgmlDType::Q5_1 => elem_count,
+ _ => elem_count / 32,
+ };
let params = (data, &dst, nb32 as i32);
unsafe { func.launch(cfg, params) }.w()?;
}
diff --git a/candle-core/tests/quantized_tests.rs b/candle-core/tests/quantized_tests.rs
index 5f7e4825..d767531a 100644
--- a/candle-core/tests/quantized_tests.rs
+++ b/candle-core/tests/quantized_tests.rs
@@ -231,10 +231,6 @@ fn quantize_q4_1(device: &Device) -> Result<()> {
}
fn quantize_q5_0(device: &Device) -> Result<()> {
- // TODO Enable this later when we enable cuda.
- if device.is_cuda() {
- return Ok(());
- }
let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
let src = Tensor::from_slice(&src, (32 * 4,), device)?;
let quant = quantized::QTensor::quantize(&src, GgmlDType::Q5_0)?;
@@ -261,10 +257,6 @@ fn quantize_q5_0(device: &Device) -> Result<()> {
}
fn quantize_q5_1(device: &Device) -> Result<()> {
- // TODO Enable this later when we enable cuda.
- if device.is_cuda() {
- return Ok(());
- }
let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
let src = Tensor::from_slice(&src, (32 * 4,), device)?;
let quant = quantized::QTensor::quantize(&src, GgmlDType::Q5_1)?;
diff --git a/candle-kernels/src/quantized.cu b/candle-kernels/src/quantized.cu
index 4d32f6fa..762395d8 100644
--- a/candle-kernels/src/quantized.cu
+++ b/candle-kernels/src/quantized.cu
@@ -575,7 +575,7 @@ static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const in
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
-static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int k) {
+static __device__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int k) {
const int i = 2*(blockDim.x*blockIdx.x + threadIdx.x);
if (i >= k) {
@@ -595,12 +595,6 @@ static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __
y[iybs + iqs + y_offset] = v.y;
}
-template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
-static void dequantize_block_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int k, cudaStream_t stream) {
- const int num_blocks = (k + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE);
- dequantize_block<qk, qr, dequantize_kernel><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
-}
-
extern "C" __global__ void dequantize_block_q4_0(const void * __restrict__ vx, float * __restrict__ yy, int nb32) {
const int i = blockIdx.x;
@@ -910,6 +904,14 @@ extern "C" __global__ void dequantize_block_q8_K(const void * __restrict__ vx, f
#endif
}
+extern "C" __global__ void dequantize_block_q5_0(const void * __restrict__ vx, float * __restrict__ yy, int nb32) {
+ return dequantize_block<QK5_0, QR5_0, dequantize_q5_0>(vx, yy, nb32);
+}
+
+extern "C" __global__ void dequantize_block_q5_1(const void * __restrict__ vx, float * __restrict__ yy, int nb32) {
+ return dequantize_block<QK5_1, QR5_1, dequantize_q5_1>(vx, yy, nb32);
+}
+
template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
static __device__ void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows) {