diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-04-18 08:36:43 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-04-18 08:36:43 +0200 |
commit | 8de0ce6cba823c53344ebdee028a13f8d564dee0 (patch) | |
tree | bcc131089fb9c7ee1d6c784cd186db991357cb17 /candle-core/src/quantized | |
parent | ce6d08df9484f1ccc45e32dcc4608c48b7c4194e (diff) | |
download | candle-8de0ce6cba823c53344ebdee028a13f8d564dee0.tar.gz candle-8de0ce6cba823c53344ebdee028a13f8d564dee0.tar.bz2 candle-8de0ce6cba823c53344ebdee028a13f8d564dee0.zip |
Add more QMMV cuda kernels. (#2077)
* Add more QMMV cuda kernels.
* Enable the new kernels.
* Adapt the testing.
Diffstat (limited to 'candle-core/src/quantized')
-rw-r--r-- | candle-core/src/quantized/cuda.rs | 18 |
1 files changed, 10 insertions, 8 deletions
diff --git a/candle-core/src/quantized/cuda.rs b/candle-core/src/quantized/cuda.rs index d6a61682..5481ca3c 100644 --- a/candle-core/src/quantized/cuda.rs +++ b/candle-core/src/quantized/cuda.rs @@ -178,8 +178,8 @@ fn mul_mat_vec_via_q8_1( if y.len() != ncols * b_size { crate::bail!("unexpected y size {}, ncols {ncols} {nrows}", y.len()) } - if b_size == 0 || b_size > 4 { - crate::bail!("only bsize between 1 and 4 are supported, got {b_size}") + if b_size == 0 || b_size > 8 { + crate::bail!("only bsize between 1 and 8 are supported, got {b_size}") } // Start by quantizing y let ncols_padded = pad(ncols, MATRIX_ROW_PADDING); @@ -204,14 +204,16 @@ fn mul_mat_vec_via_q8_1( let kernel_name = format!("{kernel_name}{b_size}"); let func = dev.get_or_load_func(&kernel_name, candle_kernels::QUANTIZED)?; let dst = unsafe { dev.alloc::<f32>(nrows * b_size).w()? }; - let nblocks = if b_size == 1 { - nrows as u32 - } else { - (nrows as u32 + 1) / 2 + // https://github.com/ggerganov/llama.cpp/blob/facb8b56f8fd3bb10a693bf0943ae9d69d0828ef/ggml-cuda/mmvq.cu#L98 + let (nblocks, nwarps) = match b_size { + 1 => (nrows as u32, 4), + 2..=4 => ((nrows as u32 + 1) / 2, 4), + 5..=8 => ((nrows as u32 + 1) / 2, 2), + _ => crate::bail!("unexpected bsize {b_size}"), }; let cfg = cudarc::driver::LaunchConfig { grid_dim: (nblocks, 1, 1), - block_dim: (WARP_SIZE as u32, 4, 1), + block_dim: (WARP_SIZE as u32, nwarps, 1), shared_mem_bytes: 0, }; @@ -398,7 +400,7 @@ impl QCudaStorage { let max_bm = if FORCE_DMMV.load(std::sync::atomic::Ordering::Relaxed) { 1 } else { - 4 + 8 }; let use_vec_kernel = match layout.shape().dims() { [b, m, _k] => b * m <= max_bm, |