summaryrefslogtreecommitdiff
path: root/candle-core/src/quantized
diff options
context:
space:
mode:
authorlaurent <laurent.mazare@gmail.com>2024-02-29 10:54:01 +0100
committerlaurent <laurent.mazare@gmail.com>2024-02-29 10:54:01 +0100
commit2c95b7394a30c11e6f3bb0c452d53e5ffef19737 (patch)
tree771b75eb970e8777f959199bac241f37efc723b7 /candle-core/src/quantized
parent4fd00b890036ef67391a9cc03f896247d0a75711 (diff)
downloadcandle-2c95b7394a30c11e6f3bb0c452d53e5ffef19737.tar.gz
candle-2c95b7394a30c11e6f3bb0c452d53e5ffef19737.tar.bz2
candle-2c95b7394a30c11e6f3bb0c452d53e5ffef19737.zip
Handle Q5_0 and Q5_1 quants in cuda.
Diffstat (limited to 'candle-core/src/quantized')
-rw-r--r--candle-core/src/quantized/cuda.rs54
1 files changed, 38 insertions, 16 deletions
diff --git a/candle-core/src/quantized/cuda.rs b/candle-core/src/quantized/cuda.rs
index 84af483d..5b684573 100644
--- a/candle-core/src/quantized/cuda.rs
+++ b/candle-core/src/quantized/cuda.rs
@@ -16,6 +16,7 @@ pub const MMQ_Y_Q4_0_AMPERE: usize = 32;
pub const NWARPS_Q4_0_AMPERE: usize = 4;
pub const GGML_CUDA_MMV_X: usize = 32;
pub const GGML_CUDA_MMV_Y: usize = 1;
+pub const CUDA_DEQUANTIZE_BLOCK_SIZE: usize = 256;
fn dequantize(
data: &CudaSlice<u8>,
@@ -25,28 +26,46 @@ fn dequantize(
) -> Result<CudaStorage> {
use cudarc::driver::LaunchAsync;
- let (kernel_name, is_k, block_dim) = match dtype {
- GgmlDType::Q4_0 => ("dequantize_block_q4_0", false, 32),
- GgmlDType::Q4_1 => ("dequantize_block_q4_1", false, 32),
- GgmlDType::Q5_0 => ("dequantize_block_q5_0", false, 32),
- GgmlDType::Q5_1 => ("dequantize_block_q5_1", false, 32),
- GgmlDType::Q8_0 => ("dequantize_block_q8_0", false, 32),
- GgmlDType::Q2K => ("dequantize_block_q2_K", true, 64),
- GgmlDType::Q3K => ("dequantize_block_q3_K", true, 64),
- GgmlDType::Q4K => ("dequantize_block_q4_K", true, 32),
- GgmlDType::Q5K => ("dequantize_block_q5_K", true, 64),
- GgmlDType::Q6K => ("dequantize_block_q6_K", true, 64),
- GgmlDType::Q8K => ("dequantize_block_q8_K", true, 32),
+ let nb = (elem_count + 255) / 256;
+ let (kernel_name, is_k, block_dim, num_blocks) = match dtype {
+ GgmlDType::Q4_0 => ("dequantize_block_q4_0", false, 32, nb),
+ GgmlDType::Q4_1 => ("dequantize_block_q4_1", false, 32, nb),
+ GgmlDType::Q5_0 => {
+ let nb = (elem_count + 2 * CUDA_DEQUANTIZE_BLOCK_SIZE - 1)
+ / (2 * CUDA_DEQUANTIZE_BLOCK_SIZE);
+ (
+ "dequantize_block_q5_0",
+ false,
+ CUDA_DEQUANTIZE_BLOCK_SIZE,
+ nb,
+ )
+ }
+ GgmlDType::Q5_1 => {
+ let nb = (elem_count + 2 * CUDA_DEQUANTIZE_BLOCK_SIZE - 1)
+ / (2 * CUDA_DEQUANTIZE_BLOCK_SIZE);
+ (
+ "dequantize_block_q5_1",
+ false,
+ CUDA_DEQUANTIZE_BLOCK_SIZE,
+ nb,
+ )
+ }
+ GgmlDType::Q8_0 => ("dequantize_block_q8_0", false, 32, nb),
+ GgmlDType::Q2K => ("dequantize_block_q2_K", true, 64, nb),
+ GgmlDType::Q3K => ("dequantize_block_q3_K", true, 64, nb),
+ GgmlDType::Q4K => ("dequantize_block_q4_K", true, 32, nb),
+ GgmlDType::Q5K => ("dequantize_block_q5_K", true, 64, nb),
+ GgmlDType::Q6K => ("dequantize_block_q6_K", true, 64, nb),
+ GgmlDType::Q8K => ("dequantize_block_q8_K", true, 32, nb),
_ => crate::bail!("unsupported dtype for dequantize {dtype:?}"),
};
let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
let dst = dev.alloc_zeros::<f32>(elem_count).w()?;
- let nb = (elem_count + 255) / 256;
// See e.g.
// https://github.com/ggerganov/llama.cpp/blob/cbbd1efa06f8c09f9dff58ff9d9af509cc4c152b/ggml-cuda.cu#L7270
let cfg = cudarc::driver::LaunchConfig {
- grid_dim: (nb as u32, 1, 1),
- block_dim: (block_dim, 1, 1),
+ grid_dim: (num_blocks as u32, 1, 1),
+ block_dim: (block_dim as u32, 1, 1),
shared_mem_bytes: 0,
};
@@ -54,7 +73,10 @@ fn dequantize(
let params = (data, &dst);
unsafe { func.launch(cfg, params) }.w()?;
} else {
- let nb32 = elem_count / 32;
+ let nb32 = match dtype {
+ GgmlDType::Q5_0 | GgmlDType::Q5_1 => elem_count,
+ _ => elem_count / 32,
+ };
let params = (data, &dst, nb32 as i32);
unsafe { func.launch(cfg, params) }.w()?;
}