author    | Laurent Mazare <laurent.mazare@gmail.com> | 2024-04-28 20:05:05 +0200
committer | GitHub <noreply@github.com>               | 2024-04-28 20:05:05 +0200
commit    | eb26e2467eb4cb5ca507324cc3245600c104f219 (patch)
tree      | 7aa8fead605a786c38d0b6d2835342240e80c9a2 /candle-core/src/quantized
parent    | c68ed8963fb6fc842f20d84baa07ff97b56aedb4 (diff)
Add the cuda dequantize f16 kernels. (#2137)
* Add the cuda dequantize f16 kernels.
* Expose the cuda kernels.
* Add some testing + fix.
* Test the other cases too.
* A few more tests.
* Add an environment variable to enable the dequantize f16 + matmul behavior.
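A minimal usage sketch for the new environment variable (not part of this commit): with CANDLE_DEQUANTIZE_ALL_F16 set, QMatMul::from_arc dequantizes the quantized weights to f16 and the forward pass goes through the new TensorF16 arm. The Q4K dtype, the tensor shapes, and the quantize-from-f32 setup below are illustrative assumptions.

```rust
use candle_core::quantized::{GgmlDType, QMatMul, QTensor};
use candle_core::{Device, Module, Result, Tensor};

fn main() -> Result<()> {
    // The flag is read once per thread (thread_local), so set it before the
    // first QMatMul is constructed. Any non-empty value other than "0" enables it.
    std::env::set_var("CANDLE_DEQUANTIZE_ALL_F16", "1");

    // Illustrative setup: quantize a random f32 weight matrix to Q4K.
    let device = Device::cuda_if_available(0)?;
    let weight = Tensor::randn(0f32, 1f32, (256, 256), &device)?;
    let qweight = QTensor::quantize(&weight, GgmlDType::Q4K)?;

    // With the env var set, this takes the TensorF16 path: the weight is kept
    // as an f16 tensor and matmuls run in f16 (inputs are cast in and back out).
    let qmatmul = QMatMul::from_arc(std::sync::Arc::new(qweight))?;
    let xs = Tensor::randn(0f32, 1f32, (1, 4, 256), &device)?;
    let ys = qmatmul.forward(&xs)?;
    println!("{:?}", ys.shape()); // [1, 4, 256]
    Ok(())
}
```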
Diffstat (limited to 'candle-core/src/quantized')
-rw-r--r-- | candle-core/src/quantized/cuda.rs       | 88
-rw-r--r-- | candle-core/src/quantized/dummy_cuda.rs |  4
-rw-r--r-- | candle-core/src/quantized/mod.rs        | 47
3 files changed, 122 insertions, 17 deletions
diff --git a/candle-core/src/quantized/cuda.rs b/candle-core/src/quantized/cuda.rs
index 5481ca3c..8e4884b2 100644
--- a/candle-core/src/quantized/cuda.rs
+++ b/candle-core/src/quantized/cuda.rs
@@ -2,6 +2,7 @@ use super::{GgmlDType, QStorage};
 use crate::quantized::k_quants::GgmlType;
 use crate::{backend::BackendDevice, cuda_backend::WrapErr};
 use crate::{CudaDevice, CudaStorage, Result};
+use half::f16;
 
 use cudarc::driver::{CudaSlice, CudaView, DeviceSlice};
 
@@ -59,7 +60,7 @@ fn quantize_q8_1(
     Ok(())
 }
 
-fn dequantize(
+fn dequantize_f32(
     data: &CudaSlice<u8>,
     dtype: GgmlDType,
     elem_count: usize,
@@ -69,27 +70,27 @@ fn dequantize(
 
     let nb = (elem_count + 255) / 256;
     let (kernel_name, is_k, block_dim, num_blocks) = match dtype {
-        GgmlDType::Q4_0 => ("dequantize_block_q4_0", false, 32, nb),
-        GgmlDType::Q4_1 => ("dequantize_block_q4_1", false, 32, nb),
+        GgmlDType::Q4_0 => ("dequantize_block_q4_0_f32", false, 32, nb),
+        GgmlDType::Q4_1 => ("dequantize_block_q4_1_f32", false, 32, nb),
         GgmlDType::Q5_0 => (
-            "dequantize_block_q5_0",
+            "dequantize_block_q5_0_f32",
             false,
             CUDA_DEQUANTIZE_BLOCK_SIZE,
             ceil_div(elem_count, 2 * CUDA_DEQUANTIZE_BLOCK_SIZE),
         ),
         GgmlDType::Q5_1 => (
-            "dequantize_block_q5_1",
+            "dequantize_block_q5_1_f32",
             false,
             CUDA_DEQUANTIZE_BLOCK_SIZE,
             ceil_div(elem_count, 2 * CUDA_DEQUANTIZE_BLOCK_SIZE),
         ),
-        GgmlDType::Q8_0 => ("dequantize_block_q8_0", false, 32, nb),
-        GgmlDType::Q2K => ("dequantize_block_q2_K", true, 64, nb),
-        GgmlDType::Q3K => ("dequantize_block_q3_K", true, 64, nb),
-        GgmlDType::Q4K => ("dequantize_block_q4_K", true, 32, nb),
-        GgmlDType::Q5K => ("dequantize_block_q5_K", true, 64, nb),
-        GgmlDType::Q6K => ("dequantize_block_q6_K", true, 64, nb),
-        GgmlDType::Q8K => ("dequantize_block_q8_K", true, 32, nb),
+        GgmlDType::Q8_0 => ("dequantize_block_q8_0_f32", false, 32, nb),
+        GgmlDType::Q2K => ("dequantize_block_q2_K_f32", true, 64, nb),
+        GgmlDType::Q3K => ("dequantize_block_q3_K_f32", true, 64, nb),
+        GgmlDType::Q4K => ("dequantize_block_q4_K_f32", true, 32, nb),
+        GgmlDType::Q5K => ("dequantize_block_q5_K_f32", true, 64, nb),
+        GgmlDType::Q6K => ("dequantize_block_q6_K_f32", true, 64, nb),
+        GgmlDType::Q8K => ("dequantize_block_q8_K_f32", true, 32, nb),
         _ => crate::bail!("unsupported dtype for dequantize {dtype:?}"),
     };
     let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
@@ -116,6 +117,63 @@ fn dequantize(
     Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
 }
 
+fn dequantize_f16(
+    data: &CudaSlice<u8>,
+    dtype: GgmlDType,
+    elem_count: usize,
+    dev: &CudaDevice,
+) -> Result<CudaStorage> {
+    use cudarc::driver::LaunchAsync;
+
+    let nb = (elem_count + 255) / 256;
+    let (kernel_name, is_k, block_dim, num_blocks) = match dtype {
+        GgmlDType::Q4_0 => ("dequantize_block_q4_0_f16", false, 32, nb),
+        GgmlDType::Q4_1 => ("dequantize_block_q4_1_f16", false, 32, nb),
+        GgmlDType::Q5_0 => (
+            "dequantize_block_q5_0_f16",
+            false,
+            CUDA_DEQUANTIZE_BLOCK_SIZE,
+            ceil_div(elem_count, 2 * CUDA_DEQUANTIZE_BLOCK_SIZE),
+        ),
+        GgmlDType::Q5_1 => (
+            "dequantize_block_q5_1_f16",
+            false,
+            CUDA_DEQUANTIZE_BLOCK_SIZE,
+            ceil_div(elem_count, 2 * CUDA_DEQUANTIZE_BLOCK_SIZE),
+        ),
+        GgmlDType::Q8_0 => ("dequantize_block_q8_0_f16", false, 32, nb),
+        GgmlDType::Q2K => ("dequantize_block_q2_K_f16", true, 64, nb),
+        GgmlDType::Q3K => ("dequantize_block_q3_K_f16", true, 64, nb),
+        GgmlDType::Q4K => ("dequantize_block_q4_K_f16", true, 32, nb),
+        GgmlDType::Q5K => ("dequantize_block_q5_K_f16", true, 64, nb),
+        GgmlDType::Q6K => ("dequantize_block_q6_K_f16", true, 64, nb),
+        GgmlDType::Q8K => ("dequantize_block_q8_K_f16", true, 32, nb),
+        _ => crate::bail!("unsupported dtype for dequantize {dtype:?}"),
+    };
+    let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
+    let dst = unsafe { dev.alloc::<f16>(elem_count).w()? };
+    // See e.g.
+    // https://github.com/ggerganov/llama.cpp/blob/cbbd1efa06f8c09f9dff58ff9d9af509cc4c152b/ggml-cuda.cu#L7270
+    let cfg = cudarc::driver::LaunchConfig {
+        grid_dim: (num_blocks as u32, 1, 1),
+        block_dim: (block_dim as u32, 1, 1),
+        shared_mem_bytes: 0,
+    };
+
+    if is_k {
+        let params = (data, &dst);
+        unsafe { func.launch(cfg, params) }.w()?;
+    } else {
+        let nb32 = match dtype {
+            GgmlDType::Q5_0 | GgmlDType::Q5_1 => elem_count,
+            _ => elem_count / 32,
+        };
+        let params = (data, &dst, nb32 as i32);
+        unsafe { func.launch(cfg, params) }.w()?;
+    }
+    Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
+}
+
 fn dequantize_mul_mat_vec(
     data: &CudaSlice<u8>,
     y: &CudaView<f32>,
@@ -341,7 +399,7 @@ impl QCudaStorage {
             | GgmlDType::Q8K
         );
         if fast_kernel {
-            return dequantize(&self.data, self.dtype, elem_count, self.device());
+            return dequantize_f32(&self.data, self.dtype, elem_count, self.device());
        }
         // Run the dequantization on cpu.
 
@@ -369,6 +427,10 @@ impl QCudaStorage {
             .storage_from_cpu_storage(&crate::CpuStorage::F32(out))
     }
 
+    pub fn dequantize_f16(&self, elem_count: usize) -> Result<CudaStorage> {
+        dequantize_f16(&self.data, self.dtype, elem_count, self.device())
+    }
+
     pub fn quantize(&mut self, src: &CudaStorage) -> Result<()> {
         // Run the quantization on cpu.
         let src = match &src.slice {
diff --git a/candle-core/src/quantized/dummy_cuda.rs b/candle-core/src/quantized/dummy_cuda.rs
index 598c5cd1..ca7b8120 100644
--- a/candle-core/src/quantized/dummy_cuda.rs
+++ b/candle-core/src/quantized/dummy_cuda.rs
@@ -24,6 +24,10 @@ impl QCudaStorage {
         Err(Error::NotCompiledWithCudaSupport)
     }
 
+    pub fn dequantize_f16(&self, _elem_count: usize) -> Result<CudaStorage> {
+        Err(Error::NotCompiledWithCudaSupport)
+    }
+
     pub fn quantize(&mut self, _src: &CudaStorage) -> Result<()> {
         Err(Error::NotCompiledWithCudaSupport)
     }
diff --git a/candle-core/src/quantized/mod.rs b/candle-core/src/quantized/mod.rs
index 47307f2e..e87072bb 100644
--- a/candle-core/src/quantized/mod.rs
+++ b/candle-core/src/quantized/mod.rs
@@ -1,4 +1,4 @@
-use crate::{CpuStorage, Device, Result, Shape, Storage, Tensor};
+use crate::{CpuStorage, DType, Device, Result, Shape, Storage, Tensor};
 use k_quants::*;
 use std::borrow::Cow;
 
@@ -360,9 +360,24 @@ impl QTensor {
     pub fn dequantize(&self, device: &Device) -> Result<Tensor> {
         let storage = self.storage.dequantize(self.shape.elem_count())?;
         let none = crate::op::BackpropOp::none();
-        let is_variable = false;
-        crate::tensor::from_storage(storage, self.shape.clone(), none, is_variable)
-            .to_device(device)
+        crate::tensor::from_storage(storage, self.shape.clone(), none, false).to_device(device)
+    }
+
+    pub fn dequantize_f16(&self, device: &Device) -> Result<Tensor> {
+        // In the CUDA case, we have a specialized kernel as this can be useful for volta
+        // architectures. https://github.com/huggingface/candle/issues/2136
+        match &self.storage {
+            QStorage::Cuda(s) => {
+                let s = s.dequantize_f16(self.shape.elem_count())?;
+                let none = crate::op::BackpropOp::none();
+                crate::tensor::from_storage(Storage::Cuda(s), self.shape.clone(), none, false)
+                    .to_device(device)
+            }
+            _ => {
+                let s = self.dequantize(device)?.to_dtype(crate::DType::F16)?;
+                Ok(s)
+            }
+        }
     }
 
     pub fn storage_size_in_bytes(&self) -> usize {
@@ -378,6 +393,7 @@ impl QTensor {
 pub enum QMatMul {
     QTensor(std::sync::Arc<QTensor>),
     Tensor(Tensor),
+    TensorF16(Tensor),
 }
 
 thread_local! {
@@ -391,6 +407,17 @@ thread_local! {
     }
 }
 
+thread_local! {
+    static DEQUANTIZE_ALL_F16: bool = {
+        match std::env::var("CANDLE_DEQUANTIZE_ALL_F16") {
+            Ok(s) => {
+                !s.is_empty() && s != "0"
+            },
+            Err(_) => false,
+        }
+    }
+}
+
 impl QMatMul {
     pub fn from_arc(qtensor: std::sync::Arc<QTensor>) -> Result<Self> {
         let dequantize = match qtensor.dtype() {
@@ -400,6 +427,9 @@ impl QMatMul {
         let t = if dequantize {
             let tensor = qtensor.dequantize(&qtensor.device())?;
             Self::Tensor(tensor)
+        } else if DEQUANTIZE_ALL_F16.with(|b| *b) {
+            let tensor = qtensor.dequantize_f16(&qtensor.device())?;
+            Self::TensorF16(tensor)
         } else {
             Self::QTensor(qtensor)
         };
@@ -486,6 +516,15 @@ impl crate::Module for QMatMul {
                 };
                 xs.matmul(&w)
             }
+            Self::TensorF16(w) => {
+                let in_dtype = xs.dtype();
+                let w = match *xs.dims() {
+                    [b1, b2, _, _] => w.broadcast_left((b1, b2))?.t()?,
+                    [bsize, _, _] => w.broadcast_left(bsize)?.t()?,
+                    _ => w.t()?,
+                };
+                xs.to_dtype(DType::F16)?.matmul(&w)?.to_dtype(in_dtype)
+            }
         }
     }
 }
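For reference, a hedged sketch of calling the new QTensor::dequantize_f16 entry point directly. On CUDA it dispatches to the *_f16 kernels added above; other backends fall back to the f32 dequantize followed by a cast to f16. The Q8_0 dtype and the shapes are illustrative assumptions.

```rust
use candle_core::quantized::{GgmlDType, QTensor};
use candle_core::{DType, Device, Result, Tensor};

fn dequantize_to_f16_example() -> Result<()> {
    let device = Device::cuda_if_available(0)?;
    let weight = Tensor::randn(0f32, 1f32, (64, 256), &device)?;
    let qweight = QTensor::quantize(&weight, GgmlDType::Q8_0)?;

    // New in this commit: dequantize straight to f16 (specialized CUDA kernel;
    // non-CUDA backends go through dequantize + to_dtype(F16)).
    let deq = qweight.dequantize_f16(&device)?;
    assert_eq!(deq.dtype(), DType::F16);
    assert_eq!(deq.dims(), &[64, 256]);
    Ok(())
}
```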