diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-04-01 00:15:48 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-04-01 00:15:48 +0200 |
commit | cd29c7ccd420a840d883361c290ee92d06b9b96c (patch) | |
tree | d387a1f1af623de2e50751d493d541eb3789684c /candle-core/src/quantized/cuda.rs | |
parent | f9954b73bac9fed91a9a08d952adc1cfb836a568 (diff) | |
download | candle-cd29c7ccd420a840d883361c290ee92d06b9b96c.tar.gz candle-cd29c7ccd420a840d883361c290ee92d06b9b96c.tar.bz2 candle-cd29c7ccd420a840d883361c290ee92d06b9b96c.zip |
More ggml cuda kernels (#1977)
* Add more cuda kernels for quantized matmul.
* Add the vec-dot bits.
* Expose the quantized matmul-vec kernels.
* Also include the quantize-q8-1 kernel.
* Glue code for the q8-1 quantization.
* mm-vec product via q8-1 quantization.
* Add a test.
* Add a mm test.
* Get the test to return some sensible results.
* Also test dmmv.
* Fix the launch params.
* Allow for tweaking the force_dmmv parameter while it's experimental.
Diffstat (limited to 'candle-core/src/quantized/cuda.rs')
-rw-r--r-- | candle-core/src/quantized/cuda.rs | 154 |
1 file changed, 147 insertions(+), 7 deletions(-)
diff --git a/candle-core/src/quantized/cuda.rs b/candle-core/src/quantized/cuda.rs index c90cf576..a8f0d622 100644 --- a/candle-core/src/quantized/cuda.rs +++ b/candle-core/src/quantized/cuda.rs @@ -2,7 +2,7 @@ use super::{GgmlDType, QStorage}; use crate::{backend::BackendDevice, cuda_backend::WrapErr}; use crate::{CudaDevice, CudaStorage, Result}; -use cudarc::driver::{CudaSlice, DeviceSlice}; +use cudarc::driver::{CudaSlice, CudaView, DeviceSlice}; pub struct QCudaStorage { data: CudaSlice<u8>, @@ -10,13 +10,43 @@ pub struct QCudaStorage { device: CudaDevice, } +static FORCE_DMMV: std::sync::atomic::AtomicBool = std::sync::atomic::AtomicBool::new(true); + +pub fn set_force_dmmv(f: bool) { + FORCE_DMMV.store(f, std::sync::atomic::Ordering::Relaxed) +} + pub const WARP_SIZE: usize = 32; pub const MMQ_X_Q4_0_AMPERE: usize = 4; pub const MMQ_Y_Q4_0_AMPERE: usize = 32; pub const NWARPS_Q4_0_AMPERE: usize = 4; pub const GGML_CUDA_MMV_X: usize = 32; pub const GGML_CUDA_MMV_Y: usize = 1; +pub const CUDA_QUANTIZE_BLOCK_SIZE: usize = 256; pub const CUDA_DEQUANTIZE_BLOCK_SIZE: usize = 256; +pub const MATRIX_ROW_PADDING: usize = 512; + +fn quantize_q8_1( + src: &CudaView<f32>, + dst: &mut CudaSlice<u8>, + elem_count: usize, + dev: &CudaDevice, +) -> Result<()> { + use cudarc::driver::LaunchAsync; + + let kx = elem_count; + let kx_padded = (kx + MATRIX_ROW_PADDING - 1) / MATRIX_ROW_PADDING * MATRIX_ROW_PADDING; + let num_blocks = (kx_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE; + let func = dev.get_or_load_func("quantize_q8_1", candle_kernels::QUANTIZED)?; + let cfg = cudarc::driver::LaunchConfig { + grid_dim: (num_blocks as u32, 1, 1), + block_dim: (CUDA_QUANTIZE_BLOCK_SIZE as u32, 1, 1), + shared_mem_bytes: 0, + }; + let params = (src, dst, kx as i32, kx_padded as i32); + unsafe { func.launch(cfg, params) }.w()?; + Ok(()) +} fn dequantize( data: &CudaSlice<u8>, @@ -60,7 +90,7 @@ fn dequantize( _ => crate::bail!("unsupported dtype for dequantize 
{dtype:?}"), }; let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?; - let dst = dev.alloc_zeros::<f32>(elem_count).w()?; + let dst = unsafe { dev.alloc::<f32>(elem_count).w()? }; // See e.g. // https://github.com/ggerganov/llama.cpp/blob/cbbd1efa06f8c09f9dff58ff9d9af509cc4c152b/ggml-cuda.cu#L7270 let cfg = cudarc::driver::LaunchConfig { @@ -83,9 +113,9 @@ fn dequantize( Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone())) } -fn dequantize_mut_mal_vec( +fn dequantize_mul_mat_vec( data: &CudaSlice<u8>, - y: &cudarc::driver::CudaView<f32>, + y: &CudaView<f32>, dtype: GgmlDType, ncols: usize, nrows: usize, @@ -107,7 +137,7 @@ fn dequantize_mut_mal_vec( _ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"), }; let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?; - let dst = dev.alloc_zeros::<f32>(nrows).w()?; + let dst = unsafe { dev.alloc::<f32>(nrows).w()? }; let block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; let cfg = cudarc::driver::LaunchConfig { grid_dim: (block_num_y as u32, 1, 1), @@ -120,6 +150,56 @@ fn dequantize_mut_mal_vec( Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone())) } +fn mul_mat_vec_via_q8_1( + data: &CudaSlice<u8>, + y: &CudaView<f32>, + dtype: GgmlDType, + ncols: usize, + nrows: usize, + dev: &CudaDevice, +) -> Result<CudaStorage> { + use cudarc::driver::LaunchAsync; + + // Start by quantizing y + let ncols_padded = (ncols + MATRIX_ROW_PADDING - 1) / MATRIX_ROW_PADDING * MATRIX_ROW_PADDING; + let y_size_in_bytes = ncols_padded * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size(); + let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes).w()? 
}; + quantize_q8_1(y, &mut y_q8_1, ncols, dev)?; + + let kernel_name = match dtype { + GgmlDType::Q4_0 => "mul_mat_vec_q4_0_q8_1_cuda", + GgmlDType::Q4_1 => "mul_mat_vec_q4_1_q8_1_cuda", + GgmlDType::Q5_0 => "mul_mat_vec_q5_0_q8_1_cuda", + GgmlDType::Q5_1 => "mul_mat_vec_q5_1_q8_1_cuda", + GgmlDType::Q8_0 => "mul_mat_vec_q8_0_q8_1_cuda", + GgmlDType::Q2K => "mul_mat_vec_q2_K_q8_1_cuda", + GgmlDType::Q3K => "mul_mat_vec_q3_K_q8_1_cuda", + GgmlDType::Q4K => "mul_mat_vec_q4_K_q8_1_cuda", + GgmlDType::Q5K => "mul_mat_vec_q5_K_q8_1_cuda", + GgmlDType::Q6K => "mul_mat_vec_q6_K_q8_1_cuda", + _ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"), + }; + let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?; + let dst = unsafe { dev.alloc::<f32>(nrows).w()? }; + let cfg = cudarc::driver::LaunchConfig { + grid_dim: (nrows as u32, 1, 1), + block_dim: (WARP_SIZE as u32, 4, 1), + shared_mem_bytes: 0, + }; + + let params = ( + data, + &y_q8_1, + &dst, + /* ncols_x */ ncols as i32, + /* nrows_x */ nrows as i32, + /* nrows_y */ ncols as i32, + /* nrows_dst */ nrows as i32, + ); + unsafe { func.launch(cfg, params) }.w()?; + Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone())) +} + impl QCudaStorage { pub fn zeros(device: &CudaDevice, el_count: usize, dtype: GgmlDType) -> Result<Self> { let size_in_bytes = el_count * dtype.type_size() / dtype.block_size(); @@ -285,8 +365,11 @@ impl QCudaStorage { crate::bail!("mismatch on matmul dim {self_shape:?} {:?}", rhs_l.shape()) } - let out = - dequantize_mut_mal_vec(&self.data, &rhs, self.dtype, ncols, nrows, self.device())?; + let out = if FORCE_DMMV.load(std::sync::atomic::Ordering::Relaxed) { + dequantize_mul_mat_vec(&self.data, &rhs, self.dtype, ncols, nrows, self.device())? + } else { + mul_mat_vec_via_q8_1(&self.data, &rhs, self.dtype, ncols, nrows, self.device())? 
+ }; let out_shape = if with_batch { vec![1, 1, nrows] } else { @@ -341,3 +424,60 @@ pub fn load_quantized<T: super::GgmlType + Send + Sync + 'static>( dtype: T::DTYPE, })) } + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn cuda_quantize_q8_1() -> Result<()> { + let dev = CudaDevice::new(0)?; + let el = 256; + let el_padded = (el + MATRIX_ROW_PADDING - 1) / MATRIX_ROW_PADDING * MATRIX_ROW_PADDING; + let y_size_in_bytes = + el_padded * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size(); + let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes).w()? }; + let vs: Vec<f32> = (0..el).map(|v| v as f32).collect(); + let y = dev.htod_sync_copy(&vs).w()?; + quantize_q8_1(&y.slice(..), &mut y_q8_1, el, &dev)?; + Ok(()) + } + + #[test] + fn cuda_mmv_q8_1() -> Result<()> { + let dev = CudaDevice::new(0)?; + let ncols = 256; + let vs: Vec<f32> = (0..ncols).map(|v| v as f32).collect(); + let y = dev.htod_sync_copy(&vs).w()?; + let mut xs = QCudaStorage::zeros(&dev, ncols, GgmlDType::Q4_0)?; + xs.quantize(&CudaStorage::wrap_cuda_slice(y.clone(), dev.clone()))?; + let cuda_storage = mul_mat_vec_via_q8_1( + &xs.data, + &y.slice(..), + /* dtype */ GgmlDType::Q4_0, + /* ncols */ ncols, + /* nrows */ 1, + &dev, + )?; + let vs = cuda_storage.as_cuda_slice::<f32>()?; + let vs = dev.dtoh_sync_copy(&vs.slice(..)).unwrap(); + assert_eq!(vs.len(), 1); + // for n = 255, n.(n+1).(2n+1) / 6 = 5559680 + // Q8 means 1/256 precision. + assert_eq!(vs[0], 5561664.5); + + let cuda_storage = dequantize_mul_mat_vec( + &xs.data, + &y.slice(..), + /* dtype */ GgmlDType::Q4_0, + /* ncols */ ncols, + /* nrows */ 1, + &dev, + )?; + let vs = cuda_storage.as_cuda_slice::<f32>()?; + let vs = dev.dtoh_sync_copy(&vs.slice(..)).unwrap(); + assert_eq!(vs.len(), 1); + assert_eq!(vs[0], 5561851.0); + Ok(()) + } +} |