summaryrefslogtreecommitdiff
path: root/candle-core/src/quantized/cuda.rs
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2024-04-01 00:15:48 +0200
committerGitHub <noreply@github.com>2024-04-01 00:15:48 +0200
commitcd29c7ccd420a840d883361c290ee92d06b9b96c (patch)
treed387a1f1af623de2e50751d493d541eb3789684c /candle-core/src/quantized/cuda.rs
parentf9954b73bac9fed91a9a08d952adc1cfb836a568 (diff)
downloadcandle-cd29c7ccd420a840d883361c290ee92d06b9b96c.tar.gz
candle-cd29c7ccd420a840d883361c290ee92d06b9b96c.tar.bz2
candle-cd29c7ccd420a840d883361c290ee92d06b9b96c.zip
More ggml cuda kernels (#1977)
* Add more cuda kernels for quantized matmul.
* Add the vec-dot bits.
* Expose the quantized matmul-vec kernels.
* Also include the quantize-q8-1 kernel.
* Glue code for the q8-1 quantization.
* mm-vec product via q8-1 quantization.
* Add a test.
* Add a mm test.
* Get the test to return some sensible results.
* Also test dmmv.
* Fix the launch params.
* Allow for tweaking the force_dmmv parameter while it's experimental.
Diffstat (limited to 'candle-core/src/quantized/cuda.rs')
-rw-r--r--candle-core/src/quantized/cuda.rs154
1 files changed, 147 insertions, 7 deletions
diff --git a/candle-core/src/quantized/cuda.rs b/candle-core/src/quantized/cuda.rs
index c90cf576..a8f0d622 100644
--- a/candle-core/src/quantized/cuda.rs
+++ b/candle-core/src/quantized/cuda.rs
@@ -2,7 +2,7 @@ use super::{GgmlDType, QStorage};
use crate::{backend::BackendDevice, cuda_backend::WrapErr};
use crate::{CudaDevice, CudaStorage, Result};
-use cudarc::driver::{CudaSlice, DeviceSlice};
+use cudarc::driver::{CudaSlice, CudaView, DeviceSlice};
pub struct QCudaStorage {
data: CudaSlice<u8>,
@@ -10,13 +10,43 @@ pub struct QCudaStorage {
device: CudaDevice,
}
+static FORCE_DMMV: std::sync::atomic::AtomicBool = std::sync::atomic::AtomicBool::new(true);
+
+pub fn set_force_dmmv(f: bool) {
+ FORCE_DMMV.store(f, std::sync::atomic::Ordering::Relaxed)
+}
+
pub const WARP_SIZE: usize = 32;
pub const MMQ_X_Q4_0_AMPERE: usize = 4;
pub const MMQ_Y_Q4_0_AMPERE: usize = 32;
pub const NWARPS_Q4_0_AMPERE: usize = 4;
pub const GGML_CUDA_MMV_X: usize = 32;
pub const GGML_CUDA_MMV_Y: usize = 1;
+pub const CUDA_QUANTIZE_BLOCK_SIZE: usize = 256;
pub const CUDA_DEQUANTIZE_BLOCK_SIZE: usize = 256;
+pub const MATRIX_ROW_PADDING: usize = 512;
+
+fn quantize_q8_1(
+ src: &CudaView<f32>,
+ dst: &mut CudaSlice<u8>,
+ elem_count: usize,
+ dev: &CudaDevice,
+) -> Result<()> {
+ use cudarc::driver::LaunchAsync;
+
+ let kx = elem_count;
+ let kx_padded = (kx + MATRIX_ROW_PADDING - 1) / MATRIX_ROW_PADDING * MATRIX_ROW_PADDING;
+ let num_blocks = (kx_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
+ let func = dev.get_or_load_func("quantize_q8_1", candle_kernels::QUANTIZED)?;
+ let cfg = cudarc::driver::LaunchConfig {
+ grid_dim: (num_blocks as u32, 1, 1),
+ block_dim: (CUDA_QUANTIZE_BLOCK_SIZE as u32, 1, 1),
+ shared_mem_bytes: 0,
+ };
+ let params = (src, dst, kx as i32, kx_padded as i32);
+ unsafe { func.launch(cfg, params) }.w()?;
+ Ok(())
+}
fn dequantize(
data: &CudaSlice<u8>,
@@ -60,7 +90,7 @@ fn dequantize(
_ => crate::bail!("unsupported dtype for dequantize {dtype:?}"),
};
let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
- let dst = dev.alloc_zeros::<f32>(elem_count).w()?;
+ let dst = unsafe { dev.alloc::<f32>(elem_count).w()? };
// See e.g.
// https://github.com/ggerganov/llama.cpp/blob/cbbd1efa06f8c09f9dff58ff9d9af509cc4c152b/ggml-cuda.cu#L7270
let cfg = cudarc::driver::LaunchConfig {
@@ -83,9 +113,9 @@ fn dequantize(
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
}
-fn dequantize_mut_mal_vec(
+fn dequantize_mul_mat_vec(
data: &CudaSlice<u8>,
- y: &cudarc::driver::CudaView<f32>,
+ y: &CudaView<f32>,
dtype: GgmlDType,
ncols: usize,
nrows: usize,
@@ -107,7 +137,7 @@ fn dequantize_mut_mal_vec(
_ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"),
};
let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
- let dst = dev.alloc_zeros::<f32>(nrows).w()?;
+ let dst = unsafe { dev.alloc::<f32>(nrows).w()? };
let block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
let cfg = cudarc::driver::LaunchConfig {
grid_dim: (block_num_y as u32, 1, 1),
@@ -120,6 +150,56 @@ fn dequantize_mut_mal_vec(
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
}
+fn mul_mat_vec_via_q8_1(
+ data: &CudaSlice<u8>,
+ y: &CudaView<f32>,
+ dtype: GgmlDType,
+ ncols: usize,
+ nrows: usize,
+ dev: &CudaDevice,
+) -> Result<CudaStorage> {
+ use cudarc::driver::LaunchAsync;
+
+ // Start by quantizing y
+ let ncols_padded = (ncols + MATRIX_ROW_PADDING - 1) / MATRIX_ROW_PADDING * MATRIX_ROW_PADDING;
+ let y_size_in_bytes = ncols_padded * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
+ let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes).w()? };
+ quantize_q8_1(y, &mut y_q8_1, ncols, dev)?;
+
+ let kernel_name = match dtype {
+ GgmlDType::Q4_0 => "mul_mat_vec_q4_0_q8_1_cuda",
+ GgmlDType::Q4_1 => "mul_mat_vec_q4_1_q8_1_cuda",
+ GgmlDType::Q5_0 => "mul_mat_vec_q5_0_q8_1_cuda",
+ GgmlDType::Q5_1 => "mul_mat_vec_q5_1_q8_1_cuda",
+ GgmlDType::Q8_0 => "mul_mat_vec_q8_0_q8_1_cuda",
+ GgmlDType::Q2K => "mul_mat_vec_q2_K_q8_1_cuda",
+ GgmlDType::Q3K => "mul_mat_vec_q3_K_q8_1_cuda",
+ GgmlDType::Q4K => "mul_mat_vec_q4_K_q8_1_cuda",
+ GgmlDType::Q5K => "mul_mat_vec_q5_K_q8_1_cuda",
+ GgmlDType::Q6K => "mul_mat_vec_q6_K_q8_1_cuda",
+ _ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"),
+ };
+ let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
+ let dst = unsafe { dev.alloc::<f32>(nrows).w()? };
+ let cfg = cudarc::driver::LaunchConfig {
+ grid_dim: (nrows as u32, 1, 1),
+ block_dim: (WARP_SIZE as u32, 4, 1),
+ shared_mem_bytes: 0,
+ };
+
+ let params = (
+ data,
+ &y_q8_1,
+ &dst,
+ /* ncols_x */ ncols as i32,
+ /* nrows_x */ nrows as i32,
+ /* nrows_y */ ncols as i32,
+ /* nrows_dst */ nrows as i32,
+ );
+ unsafe { func.launch(cfg, params) }.w()?;
+ Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
+}
+
impl QCudaStorage {
pub fn zeros(device: &CudaDevice, el_count: usize, dtype: GgmlDType) -> Result<Self> {
let size_in_bytes = el_count * dtype.type_size() / dtype.block_size();
@@ -285,8 +365,11 @@ impl QCudaStorage {
crate::bail!("mismatch on matmul dim {self_shape:?} {:?}", rhs_l.shape())
}
- let out =
- dequantize_mut_mal_vec(&self.data, &rhs, self.dtype, ncols, nrows, self.device())?;
+ let out = if FORCE_DMMV.load(std::sync::atomic::Ordering::Relaxed) {
+ dequantize_mul_mat_vec(&self.data, &rhs, self.dtype, ncols, nrows, self.device())?
+ } else {
+ mul_mat_vec_via_q8_1(&self.data, &rhs, self.dtype, ncols, nrows, self.device())?
+ };
let out_shape = if with_batch {
vec![1, 1, nrows]
} else {
@@ -341,3 +424,60 @@ pub fn load_quantized<T: super::GgmlType + Send + Sync + 'static>(
dtype: T::DTYPE,
}))
}
+
+#[cfg(test)]
+mod test {
+ use super::*;
+
+ #[test]
+ fn cuda_quantize_q8_1() -> Result<()> {
+ let dev = CudaDevice::new(0)?;
+ let el = 256;
+ let el_padded = (el + MATRIX_ROW_PADDING - 1) / MATRIX_ROW_PADDING * MATRIX_ROW_PADDING;
+ let y_size_in_bytes =
+ el_padded * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
+ let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes).w()? };
+ let vs: Vec<f32> = (0..el).map(|v| v as f32).collect();
+ let y = dev.htod_sync_copy(&vs).w()?;
+ quantize_q8_1(&y.slice(..), &mut y_q8_1, el, &dev)?;
+ Ok(())
+ }
+
+ #[test]
+ fn cuda_mmv_q8_1() -> Result<()> {
+ let dev = CudaDevice::new(0)?;
+ let ncols = 256;
+ let vs: Vec<f32> = (0..ncols).map(|v| v as f32).collect();
+ let y = dev.htod_sync_copy(&vs).w()?;
+ let mut xs = QCudaStorage::zeros(&dev, ncols, GgmlDType::Q4_0)?;
+ xs.quantize(&CudaStorage::wrap_cuda_slice(y.clone(), dev.clone()))?;
+ let cuda_storage = mul_mat_vec_via_q8_1(
+ &xs.data,
+ &y.slice(..),
+ /* dtype */ GgmlDType::Q4_0,
+ /* ncols */ ncols,
+ /* nrows */ 1,
+ &dev,
+ )?;
+ let vs = cuda_storage.as_cuda_slice::<f32>()?;
+ let vs = dev.dtoh_sync_copy(&vs.slice(..)).unwrap();
+ assert_eq!(vs.len(), 1);
+ // for n = 255, n.(n+1).(2n+1) / 6 = 5559680
+ // Q8 means 1/256 precision.
+ assert_eq!(vs[0], 5561664.5);
+
+ let cuda_storage = dequantize_mul_mat_vec(
+ &xs.data,
+ &y.slice(..),
+ /* dtype */ GgmlDType::Q4_0,
+ /* ncols */ ncols,
+ /* nrows */ 1,
+ &dev,
+ )?;
+ let vs = cuda_storage.as_cuda_slice::<f32>()?;
+ let vs = dev.dtoh_sync_copy(&vs.slice(..)).unwrap();
+ assert_eq!(vs.len(), 1);
+ assert_eq!(vs[0], 5561851.0);
+ Ok(())
+ }
+}