summaryrefslogtreecommitdiff
path: root/candle-core/src/quantized
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2024-04-15 08:32:47 +0200
committerGitHub <noreply@github.com>2024-04-15 08:32:47 +0200
commitf7d5bf5b97071c5bb299084559992e4681fcf277 (patch)
tree86853362dd1556131fa9f981cf4abd58f1c058f8 /candle-core/src/quantized
parentc119600d6edbe02349b93cf08372409bffd4cf6a (diff)
downloadcandle-f7d5bf5b97071c5bb299084559992e4681fcf277.tar.gz
candle-f7d5bf5b97071c5bb299084559992e4681fcf277.tar.bz2
candle-f7d5bf5b97071c5bb299084559992e4681fcf277.zip
Faster kernels for quantized matmul on cuda (#2060)
* Hook the quantized matmul cuda kernels. * Add a (currently broken) test. * Kernel fixes. * Fix by transposing the rhs matrix. * Add the q4-1 kernels. * Proper block sizes. * More details in the tests.
Diffstat (limited to 'candle-core/src/quantized')
-rw-r--r--candle-core/src/quantized/cuda.rs143
1 files changed, 137 insertions, 6 deletions
diff --git a/candle-core/src/quantized/cuda.rs b/candle-core/src/quantized/cuda.rs
index 07f8c13e..487431f6 100644
--- a/candle-core/src/quantized/cuda.rs
+++ b/candle-core/src/quantized/cuda.rs
@@ -40,6 +40,7 @@ fn quantize_q8_1(
src: &CudaView<f32>,
dst: &mut CudaSlice<u8>,
elem_count: usize,
+ ky: usize,
dev: &CudaDevice,
) -> Result<()> {
use cudarc::driver::LaunchAsync;
@@ -49,7 +50,7 @@ fn quantize_q8_1(
let num_blocks = ceil_div(kx_padded, CUDA_QUANTIZE_BLOCK_SIZE);
let func = dev.get_or_load_func("quantize_q8_1", candle_kernels::QUANTIZED)?;
let cfg = cudarc::driver::LaunchConfig {
- grid_dim: (num_blocks as u32, 1, 1),
+ grid_dim: (num_blocks as u32, ky as u32, 1),
block_dim: (CUDA_QUANTIZE_BLOCK_SIZE as u32, 1, 1),
shared_mem_bytes: 0,
};
@@ -180,7 +181,7 @@ fn mul_mat_vec_via_q8_1(
let ncols_padded = pad(ncols, MATRIX_ROW_PADDING);
let y_size_in_bytes = ncols_padded * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes).w()? };
- quantize_q8_1(y, &mut y_q8_1, ncols, dev)?;
+ quantize_q8_1(y, &mut y_q8_1, ncols, 1, dev)?;
let kernel_name = match dtype {
GgmlDType::Q4_0 => "mul_mat_vec_q4_0_q8_1_cuda",
@@ -216,6 +217,75 @@ fn mul_mat_vec_via_q8_1(
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
}
+fn mul_mat_via_q8_1(
+ data: &CudaSlice<u8>,
+ y: &CudaView<f32>,
+ dtype: GgmlDType,
+ x_rows: usize,
+ x_cols: usize,
+ y_rows: usize,
+ y_cols: usize,
+ dev: &CudaDevice,
+) -> Result<CudaStorage> {
+ use cudarc::driver::LaunchAsync;
+
+ let data_elems = data.len() / dtype.type_size() * dtype.block_size();
+ if data_elems < x_rows * x_cols {
+ crate::bail!("unexpected lhs size {}, {x_rows} {x_cols}", data_elems)
+ }
+ if y.len() != y_rows * y_cols {
+ crate::bail!("unexpected y size {}, {y_rows} {y_cols}", y.len())
+ }
+ if x_cols != y_rows {
+ crate::bail!("unexpected x/y size {x_rows} {x_cols} {y_rows} {y_cols}")
+ }
+ let k = x_cols;
+ // Start by quantizing y
+ let k_padded = pad(k, MATRIX_ROW_PADDING);
+ let y_size_in_bytes =
+ k_padded * y_rows * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
+ let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes).w()? };
+ quantize_q8_1(y, &mut y_q8_1, k, y_cols, dev)?;
+
+ let (kernel_name, mmq_x, mmq_y) = match dtype {
+ GgmlDType::Q4_0 => ("mul_mat_q4_0", 64, 128),
+ GgmlDType::Q4_1 => ("mul_mat_q4_1", 64, 128),
+ GgmlDType::Q5_0 => ("mul_mat_q5_0", 128, 64),
+ GgmlDType::Q5_1 => ("mul_mat_q5_1", 128, 64),
+ GgmlDType::Q8_0 => ("mul_mat_q8_0", 128, 64),
+ GgmlDType::Q2K => ("mul_mat_q2_K", 64, 128),
+ GgmlDType::Q3K => ("mul_mat_q3_K", 128, 128),
+ GgmlDType::Q4K => ("mul_mat_q4_K", 64, 128),
+ GgmlDType::Q5K => ("mul_mat_q5_K", 64, 128),
+ GgmlDType::Q6K => ("mul_mat_q6_K", 64, 64),
+ _ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"),
+ };
+ let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
+ let dst = unsafe { dev.alloc::<f32>(x_rows * y_cols).w()? };
+ let cfg = cudarc::driver::LaunchConfig {
+ grid_dim: (
+ ceil_div(x_rows, mmq_y) as u32,
+ ceil_div(y_cols, mmq_x) as u32,
+ 1,
+ ),
+ block_dim: (WARP_SIZE as u32, 4, 1),
+ shared_mem_bytes: 0,
+ };
+
+ let params = (
+ /* vx */ data,
+ /* vy */ &y_q8_1,
+ /* dst */ &dst,
+ /* ncols_x */ x_cols as i32,
+ /* nrows_x */ x_rows as i32,
+ /* ncols_y */ y_cols as i32,
+ /* nrows_y */ k_padded as i32,
+ /* nrows_dst */ x_rows as i32,
+ );
+ unsafe { func.launch(cfg, params) }.w()?;
+ Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
+}
+
impl QCudaStorage {
pub fn zeros(device: &CudaDevice, el_count: usize, dtype: GgmlDType) -> Result<Self> {
let size_in_bytes = ceil_div(el_count, dtype.block_size()) * dtype.type_size();
@@ -373,9 +443,30 @@ impl QCudaStorage {
crate::bail!("mismatch on matmul dim {self_shape:?} {:?}", layout.shape())
}
- let data_f32 = self.dequantize(n * k)?;
- let rhs_l = crate::Layout::new((k, n).into(), vec![1, k], 0).broadcast_as((b, k, n))?;
- let out = storage.matmul(&data_f32, (b, m, n, k), layout, &rhs_l)?;
+ let out = if FORCE_DMMV.load(std::sync::atomic::Ordering::Relaxed) {
+ let data_f32 = self.dequantize(n * k)?;
+ let rhs_l = crate::Layout::new((k, n).into(), vec![1, k], 0).broadcast_as((b, k, n))?;
+ storage.matmul(&data_f32, (b, m, n, k), layout, &rhs_l)?
+ } else {
+ let storage = storage.as_cuda_slice::<f32>()?;
+ let storage = match layout.contiguous_offsets() {
+ Some((o1, o2)) => storage.slice(o1..o2),
+ None => Err(crate::Error::RequiresContiguous {
+ op: "quantized-matmul",
+ }
+ .bt())?,
+ };
+ mul_mat_via_q8_1(
+ &self.data,
+ &storage,
+ self.dtype,
+ /* x_rows */ n,
+ /* x_cols */ k,
+ /* y_rows */ k,
+ /* y_cols */ m,
+ self.device(),
+ )?
+ };
let mut out_shape = layout.shape().dims().to_vec();
out_shape.pop();
out_shape.push(n);
@@ -412,7 +503,7 @@ mod test {
let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes).w()? };
let vs: Vec<f32> = (0..el).map(|v| v as f32).collect();
let y = dev.htod_sync_copy(&vs).w()?;
- quantize_q8_1(&y.slice(..), &mut y_q8_1, el, &dev)?;
+ quantize_q8_1(&y.slice(..), &mut y_q8_1, el, 1, &dev)?;
Ok(())
}
@@ -453,4 +544,44 @@ mod test {
assert_eq!(vs[0], 5561851.0);
Ok(())
}
+
+ #[test]
+ fn cuda_mm_q8_1() -> Result<()> {
+ let dev = CudaDevice::new(0)?;
+ let ncols = 256;
+ let vs: Vec<f32> = (0..ncols * 4).map(|v| v as f32 / 4.).collect();
+ let y = dev.htod_sync_copy(&vs).w()?;
+ let mut xs = QCudaStorage::zeros(&dev, ncols * 4, GgmlDType::Q4_0)?;
+ xs.quantize(&CudaStorage::wrap_cuda_slice(y.clone(), dev.clone()))?;
+ let cuda_storage = mul_mat_via_q8_1(
+ &xs.data,
+ &y.slice(..),
+ /* dtype */ GgmlDType::Q4_0,
+ /* x_rows */ 4,
+ /* x_cols */ ncols,
+ /* y_rows */ ncols,
+ /* y_cols */ 4,
+ &dev,
+ )?;
+ let vs = cuda_storage.as_cuda_slice::<f32>()?;
+ let vs = dev.dtoh_sync_copy(&vs.slice(..)).unwrap();
+
+ /*
+ x = torch.tensor([float(v) for v in range(1024)]).reshape(4, 256)
+ x @ x.t() / 16
+ tensor([[ 347480.0000, 869720.0000, 1391960.0000, 1914200.0000],
+ [ 869720.0000, 2440536.0000, 4011352.0000, 5582166.5000],
+ [ 1391960.0000, 4011352.0000, 6630742.0000, 9250132.0000],
+ [ 1914200.0000, 5582166.5000, 9250132.0000, 12918099.0000]])
+ */
+ assert_eq!(vs.len(), 16);
+ assert_eq!(vs[0], 347604.0);
+ assert_eq!(vs[1], 888153.06);
+ assert_eq!(vs[4], 869780.7);
+ assert_eq!(vs[5], 2483145.0);
+ assert_eq!(vs[11], 9407368.0);
+ assert_eq!(vs[14], 9470856.0);
+ assert_eq!(vs[15], 13138824.0);
+ Ok(())
+ }
}