diff options
Diffstat (limited to 'candle-core/src/quantized/cuda.rs')
-rw-r--r-- | candle-core/src/quantized/cuda.rs | 54 |
1 files changed, 38 insertions, 16 deletions
diff --git a/candle-core/src/quantized/cuda.rs b/candle-core/src/quantized/cuda.rs index 84af483d..5b684573 100644 --- a/candle-core/src/quantized/cuda.rs +++ b/candle-core/src/quantized/cuda.rs @@ -16,6 +16,7 @@ pub const MMQ_Y_Q4_0_AMPERE: usize = 32; pub const NWARPS_Q4_0_AMPERE: usize = 4; pub const GGML_CUDA_MMV_X: usize = 32; pub const GGML_CUDA_MMV_Y: usize = 1; +pub const CUDA_DEQUANTIZE_BLOCK_SIZE: usize = 256; fn dequantize( data: &CudaSlice<u8>, @@ -25,28 +26,46 @@ fn dequantize( ) -> Result<CudaStorage> { use cudarc::driver::LaunchAsync; - let (kernel_name, is_k, block_dim) = match dtype { - GgmlDType::Q4_0 => ("dequantize_block_q4_0", false, 32), - GgmlDType::Q4_1 => ("dequantize_block_q4_1", false, 32), - GgmlDType::Q5_0 => ("dequantize_block_q5_0", false, 32), - GgmlDType::Q5_1 => ("dequantize_block_q5_1", false, 32), - GgmlDType::Q8_0 => ("dequantize_block_q8_0", false, 32), - GgmlDType::Q2K => ("dequantize_block_q2_K", true, 64), - GgmlDType::Q3K => ("dequantize_block_q3_K", true, 64), - GgmlDType::Q4K => ("dequantize_block_q4_K", true, 32), - GgmlDType::Q5K => ("dequantize_block_q5_K", true, 64), - GgmlDType::Q6K => ("dequantize_block_q6_K", true, 64), - GgmlDType::Q8K => ("dequantize_block_q8_K", true, 32), + let nb = (elem_count + 255) / 256; + let (kernel_name, is_k, block_dim, num_blocks) = match dtype { + GgmlDType::Q4_0 => ("dequantize_block_q4_0", false, 32, nb), + GgmlDType::Q4_1 => ("dequantize_block_q4_1", false, 32, nb), + GgmlDType::Q5_0 => { + let nb = (elem_count + 2 * CUDA_DEQUANTIZE_BLOCK_SIZE - 1) + / (2 * CUDA_DEQUANTIZE_BLOCK_SIZE); + ( + "dequantize_block_q5_0", + false, + CUDA_DEQUANTIZE_BLOCK_SIZE, + nb, + ) + } + GgmlDType::Q5_1 => { + let nb = (elem_count + 2 * CUDA_DEQUANTIZE_BLOCK_SIZE - 1) + / (2 * CUDA_DEQUANTIZE_BLOCK_SIZE); + ( + "dequantize_block_q5_1", + false, + CUDA_DEQUANTIZE_BLOCK_SIZE, + nb, + ) + } + GgmlDType::Q8_0 => ("dequantize_block_q8_0", false, 32, nb), + GgmlDType::Q2K => ("dequantize_block_q2_K", true, 64, nb), + GgmlDType::Q3K => ("dequantize_block_q3_K", true, 64, nb), + GgmlDType::Q4K => ("dequantize_block_q4_K", true, 32, nb), + GgmlDType::Q5K => ("dequantize_block_q5_K", true, 64, nb), + GgmlDType::Q6K => ("dequantize_block_q6_K", true, 64, nb), + GgmlDType::Q8K => ("dequantize_block_q8_K", true, 32, nb), _ => crate::bail!("unsupported dtype for dequantize {dtype:?}"), }; let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?; let dst = dev.alloc_zeros::<f32>(elem_count).w()?; - let nb = (elem_count + 255) / 256; // See e.g. // https://github.com/ggerganov/llama.cpp/blob/cbbd1efa06f8c09f9dff58ff9d9af509cc4c152b/ggml-cuda.cu#L7270 let cfg = cudarc::driver::LaunchConfig { - grid_dim: (nb as u32, 1, 1), - block_dim: (block_dim, 1, 1), + grid_dim: (num_blocks as u32, 1, 1), + block_dim: (block_dim as u32, 1, 1), shared_mem_bytes: 0, }; @@ -54,7 +73,10 @@ fn dequantize( let params = (data, &dst); unsafe { func.launch(cfg, params) }.w()?; } else { - let nb32 = elem_count / 32; + let nb32 = match dtype { + GgmlDType::Q5_0 | GgmlDType::Q5_1 => elem_count, + _ => elem_count / 32, + }; let params = (data, &dst, nb32 as i32); unsafe { func.launch(cfg, params) }.w()?; } |