author    Laurent Mazare <laurent.mazare@gmail.com>  2024-04-28 20:05:05 +0200
committer GitHub <noreply@github.com>                2024-04-28 20:05:05 +0200
commit    eb26e2467eb4cb5ca507324cc3245600c104f219 (patch)
tree      7aa8fead605a786c38d0b6d2835342240e80c9a2 /candle-core/src/quantized
parent    c68ed8963fb6fc842f20d84baa07ff97b56aedb4 (diff)
Add the cuda dequantize f16 kernels. (#2137)
* Add the cuda dequantize f16 kernels.
* Expose the cuda kernels.
* Add some testing + fix.
* Test the other cases too.
* A few more tests.
* Add an environment variable to enable the dequantize f16 + matmul behavior.
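The last item is the user-facing switch: CANDLE_DEQUANTIZE_ALL_F16 is sampled once per thread when the first QMatMul is built (see the thread_local in mod.rs below), so it must be set before the model is constructed. A minimal usage sketch, assuming a qtensor: Arc<QTensor> and an input tensor xs are already in scope:

    use candle_core::quantized::QMatMul;
    use candle_core::Module;

    // Any non-empty value other than "0" enables the path; set it before the
    // first QMatMul is constructed, since the thread_local caches the lookup.
    std::env::set_var("CANDLE_DEQUANTIZE_ALL_F16", "1");

    // With the flag on, from_arc dequantizes the weights to f16 up front
    // (QMatMul::TensorF16) instead of keeping them in their quantized form.
    let mm = QMatMul::from_arc(qtensor)?;
    let ys = mm.forward(&xs)?; // matmul runs in f16, result cast back to xs.dtype()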
Diffstat (limited to 'candle-core/src/quantized')
-rw-r--r--  candle-core/src/quantized/cuda.rs        88
-rw-r--r--  candle-core/src/quantized/dummy_cuda.rs   4
-rw-r--r--  candle-core/src/quantized/mod.rs         47
3 files changed, 122 insertions(+), 17 deletions(-)
diff --git a/candle-core/src/quantized/cuda.rs b/candle-core/src/quantized/cuda.rs
index 5481ca3c..8e4884b2 100644
--- a/candle-core/src/quantized/cuda.rs
+++ b/candle-core/src/quantized/cuda.rs
@@ -2,6 +2,7 @@ use super::{GgmlDType, QStorage};
use crate::quantized::k_quants::GgmlType;
use crate::{backend::BackendDevice, cuda_backend::WrapErr};
use crate::{CudaDevice, CudaStorage, Result};
+use half::f16;
use cudarc::driver::{CudaSlice, CudaView, DeviceSlice};
@@ -59,7 +60,7 @@ fn quantize_q8_1(
Ok(())
}
-fn dequantize(
+fn dequantize_f32(
data: &CudaSlice<u8>,
dtype: GgmlDType,
elem_count: usize,
@@ -69,27 +70,27 @@ fn dequantize(
let nb = (elem_count + 255) / 256;
let (kernel_name, is_k, block_dim, num_blocks) = match dtype {
- GgmlDType::Q4_0 => ("dequantize_block_q4_0", false, 32, nb),
- GgmlDType::Q4_1 => ("dequantize_block_q4_1", false, 32, nb),
+ GgmlDType::Q4_0 => ("dequantize_block_q4_0_f32", false, 32, nb),
+ GgmlDType::Q4_1 => ("dequantize_block_q4_1_f32", false, 32, nb),
GgmlDType::Q5_0 => (
- "dequantize_block_q5_0",
+ "dequantize_block_q5_0_f32",
false,
CUDA_DEQUANTIZE_BLOCK_SIZE,
ceil_div(elem_count, 2 * CUDA_DEQUANTIZE_BLOCK_SIZE),
),
GgmlDType::Q5_1 => (
- "dequantize_block_q5_1",
+ "dequantize_block_q5_1_f32",
false,
CUDA_DEQUANTIZE_BLOCK_SIZE,
ceil_div(elem_count, 2 * CUDA_DEQUANTIZE_BLOCK_SIZE),
),
- GgmlDType::Q8_0 => ("dequantize_block_q8_0", false, 32, nb),
- GgmlDType::Q2K => ("dequantize_block_q2_K", true, 64, nb),
- GgmlDType::Q3K => ("dequantize_block_q3_K", true, 64, nb),
- GgmlDType::Q4K => ("dequantize_block_q4_K", true, 32, nb),
- GgmlDType::Q5K => ("dequantize_block_q5_K", true, 64, nb),
- GgmlDType::Q6K => ("dequantize_block_q6_K", true, 64, nb),
- GgmlDType::Q8K => ("dequantize_block_q8_K", true, 32, nb),
+ GgmlDType::Q8_0 => ("dequantize_block_q8_0_f32", false, 32, nb),
+ GgmlDType::Q2K => ("dequantize_block_q2_K_f32", true, 64, nb),
+ GgmlDType::Q3K => ("dequantize_block_q3_K_f32", true, 64, nb),
+ GgmlDType::Q4K => ("dequantize_block_q4_K_f32", true, 32, nb),
+ GgmlDType::Q5K => ("dequantize_block_q5_K_f32", true, 64, nb),
+ GgmlDType::Q6K => ("dequantize_block_q6_K_f32", true, 64, nb),
+ GgmlDType::Q8K => ("dequantize_block_q8_K_f32", true, 32, nb),
_ => crate::bail!("unsupported dtype for dequantize {dtype:?}"),
};
let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
@@ -116,6 +117,63 @@ fn dequantize(
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
}
+fn dequantize_f16(
+ data: &CudaSlice<u8>,
+ dtype: GgmlDType,
+ elem_count: usize,
+ dev: &CudaDevice,
+) -> Result<CudaStorage> {
+ use cudarc::driver::LaunchAsync;
+
+ let nb = (elem_count + 255) / 256;
+ let (kernel_name, is_k, block_dim, num_blocks) = match dtype {
+ GgmlDType::Q4_0 => ("dequantize_block_q4_0_f16", false, 32, nb),
+ GgmlDType::Q4_1 => ("dequantize_block_q4_1_f16", false, 32, nb),
+ GgmlDType::Q5_0 => (
+ "dequantize_block_q5_0_f16",
+ false,
+ CUDA_DEQUANTIZE_BLOCK_SIZE,
+ ceil_div(elem_count, 2 * CUDA_DEQUANTIZE_BLOCK_SIZE),
+ ),
+ GgmlDType::Q5_1 => (
+ "dequantize_block_q5_1_f16",
+ false,
+ CUDA_DEQUANTIZE_BLOCK_SIZE,
+ ceil_div(elem_count, 2 * CUDA_DEQUANTIZE_BLOCK_SIZE),
+ ),
+ GgmlDType::Q8_0 => ("dequantize_block_q8_0_f16", false, 32, nb),
+ GgmlDType::Q2K => ("dequantize_block_q2_K_f16", true, 64, nb),
+ GgmlDType::Q3K => ("dequantize_block_q3_K_f16", true, 64, nb),
+ GgmlDType::Q4K => ("dequantize_block_q4_K_f16", true, 32, nb),
+ GgmlDType::Q5K => ("dequantize_block_q5_K_f16", true, 64, nb),
+ GgmlDType::Q6K => ("dequantize_block_q6_K_f16", true, 64, nb),
+ GgmlDType::Q8K => ("dequantize_block_q8_K_f16", true, 32, nb),
+ _ => crate::bail!("unsupported dtype for dequantize {dtype:?}"),
+ };
+ let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
+ let dst = unsafe { dev.alloc::<f16>(elem_count).w()? };
+ // See e.g.
+ // https://github.com/ggerganov/llama.cpp/blob/cbbd1efa06f8c09f9dff58ff9d9af509cc4c152b/ggml-cuda.cu#L7270
+ let cfg = cudarc::driver::LaunchConfig {
+ grid_dim: (num_blocks as u32, 1, 1),
+ block_dim: (block_dim as u32, 1, 1),
+ shared_mem_bytes: 0,
+ };
+
+ if is_k {
+ let params = (data, &dst);
+ unsafe { func.launch(cfg, params) }.w()?;
+ } else {
+ let nb32 = match dtype {
+ GgmlDType::Q5_0 | GgmlDType::Q5_1 => elem_count,
+ _ => elem_count / 32,
+ };
+ let params = (data, &dst, nb32 as i32);
+ unsafe { func.launch(cfg, params) }.w()?;
+ }
+ Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
+}
+
fn dequantize_mul_mat_vec(
data: &CudaSlice<u8>,
y: &CudaView<f32>,
@@ -341,7 +399,7 @@ impl QCudaStorage {
| GgmlDType::Q8K
);
if fast_kernel {
- return dequantize(&self.data, self.dtype, elem_count, self.device());
+ return dequantize_f32(&self.data, self.dtype, elem_count, self.device());
}
// Run the dequantization on cpu.
@@ -369,6 +427,10 @@ impl QCudaStorage {
.storage_from_cpu_storage(&crate::CpuStorage::F32(out))
}
+ pub fn dequantize_f16(&self, elem_count: usize) -> Result<CudaStorage> {
+ dequantize_f16(&self.data, self.dtype, elem_count, self.device())
+ }
+
pub fn quantize(&mut self, src: &CudaStorage) -> Result<()> {
// Run the quantization on cpu.
let src = match &src.slice {
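A note on the grid sizing used by both dequantize variants above: k-quant blocks hold 256 elements, so nb = (elem_count + 255) / 256 is a ceiling division, and ceil_div (a helper defined elsewhere in this file) is the same operation used for the Q5_0/Q5_1 launch configs. The standard formulation, as a sketch:

    // Smallest n such that n * b >= a (for a > 0); matches (a + b - 1) / b.
    fn ceil_div(a: usize, b: usize) -> usize {
        (a + b - 1) / b
    }

    // e.g. one CUDA block per 256-element k-quant block:
    // let nb = ceil_div(elem_count, 256); // same as (elem_count + 255) / 256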
diff --git a/candle-core/src/quantized/dummy_cuda.rs b/candle-core/src/quantized/dummy_cuda.rs
index 598c5cd1..ca7b8120 100644
--- a/candle-core/src/quantized/dummy_cuda.rs
+++ b/candle-core/src/quantized/dummy_cuda.rs
@@ -24,6 +24,10 @@ impl QCudaStorage {
Err(Error::NotCompiledWithCudaSupport)
}
+ pub fn dequantize_f16(&self, _elem_count: usize) -> Result<CudaStorage> {
+ Err(Error::NotCompiledWithCudaSupport)
+ }
+
pub fn quantize(&mut self, _src: &CudaStorage) -> Result<()> {
Err(Error::NotCompiledWithCudaSupport)
}
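The dummy versions keep non-CUDA builds compiling against the same QCudaStorage surface; every method simply returns Error::NotCompiledWithCudaSupport. A sketch of the cfg-gated module selection this pattern relies on (the exact wiring in candle's mod.rs is an assumption here):

    // Pick the real implementation when the cuda feature is on, the stub otherwise.
    #[cfg(feature = "cuda")]
    mod cuda;
    #[cfg(not(feature = "cuda"))]
    mod dummy_cuda;
    #[cfg(not(feature = "cuda"))]
    use dummy_cuda as cuda; // identical API, every call errors at runtime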
diff --git a/candle-core/src/quantized/mod.rs b/candle-core/src/quantized/mod.rs
index 47307f2e..e87072bb 100644
--- a/candle-core/src/quantized/mod.rs
+++ b/candle-core/src/quantized/mod.rs
@@ -1,4 +1,4 @@
-use crate::{CpuStorage, Device, Result, Shape, Storage, Tensor};
+use crate::{CpuStorage, DType, Device, Result, Shape, Storage, Tensor};
use k_quants::*;
use std::borrow::Cow;
@@ -360,9 +360,24 @@ impl QTensor {
pub fn dequantize(&self, device: &Device) -> Result<Tensor> {
let storage = self.storage.dequantize(self.shape.elem_count())?;
let none = crate::op::BackpropOp::none();
- let is_variable = false;
- crate::tensor::from_storage(storage, self.shape.clone(), none, is_variable)
- .to_device(device)
+ crate::tensor::from_storage(storage, self.shape.clone(), none, false).to_device(device)
+ }
+
+ pub fn dequantize_f16(&self, device: &Device) -> Result<Tensor> {
+ // In the CUDA case, we have a specialized kernel as this can be useful for volta
+ // architectures. https://github.com/huggingface/candle/issues/2136
+ match &self.storage {
+ QStorage::Cuda(s) => {
+ let s = s.dequantize_f16(self.shape.elem_count())?;
+ let none = crate::op::BackpropOp::none();
+ crate::tensor::from_storage(Storage::Cuda(s), self.shape.clone(), none, false)
+ .to_device(device)
+ }
+ _ => {
+ let s = self.dequantize(device)?.to_dtype(crate::DType::F16)?;
+ Ok(s)
+ }
+ }
}
pub fn storage_size_in_bytes(&self) -> usize {
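The match above means only CUDA storage reaches the specialized kernels; every other backend goes through the plain f32 dequantize followed by a cast, so the result dtype is F16 either way. A small usage sketch (the qtensor and device bindings are assumed):

    // On a CUDA device this calls the new f16 kernels directly; on CPU or Metal
    // it falls back to dequantize(device)? followed by to_dtype(F16).
    let w = qtensor.dequantize_f16(&device)?;
    assert_eq!(w.dtype(), candle_core::DType::F16);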
@@ -378,6 +393,7 @@ impl QTensor {
pub enum QMatMul {
QTensor(std::sync::Arc<QTensor>),
Tensor(Tensor),
+ TensorF16(Tensor),
}
thread_local! {
@@ -391,6 +407,17 @@ thread_local! {
}
}
+thread_local! {
+ static DEQUANTIZE_ALL_F16: bool = {
+ match std::env::var("CANDLE_DEQUANTIZE_ALL_F16") {
+ Ok(s) => {
+ !s.is_empty() && s != "0"
+ },
+ Err(_) => false,
+ }
+ }
+}
+
impl QMatMul {
pub fn from_arc(qtensor: std::sync::Arc<QTensor>) -> Result<Self> {
let dequantize = match qtensor.dtype() {
@@ -400,6 +427,9 @@ impl QMatMul {
let t = if dequantize {
let tensor = qtensor.dequantize(&qtensor.device())?;
Self::Tensor(tensor)
+ } else if DEQUANTIZE_ALL_F16.with(|b| *b) {
+ let tensor = qtensor.dequantize_f16(&qtensor.device())?;
+ Self::TensorF16(tensor)
} else {
Self::QTensor(qtensor)
};
@@ -486,6 +516,15 @@ impl crate::Module for QMatMul {
};
xs.matmul(&w)
}
+ Self::TensorF16(w) => {
+ let in_dtype = xs.dtype();
+ let w = match *xs.dims() {
+ [b1, b2, _, _] => w.broadcast_left((b1, b2))?.t()?,
+ [bsize, _, _] => w.broadcast_left(bsize)?.t()?,
+ _ => w.t()?,
+ };
+ xs.to_dtype(DType::F16)?.matmul(&w)?.to_dtype(in_dtype)
+ }
}
}
}
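The TensorF16 arm mirrors the existing Tensor arm but pins the matmul to f16: it records the input dtype, casts the activations down, multiplies against the already-f16 transposed weights, and casts the result back so callers never observe a dtype change. The data flow for the plain 2D case, in isolation:

    // Sketch of the new arm without the batched broadcast_left cases.
    let in_dtype = xs.dtype();            // e.g. DType::F32
    let ys = xs.to_dtype(DType::F16)?     // activations -> f16
        .matmul(&w.t()?)?                 // f16 x f16 matmul
        .to_dtype(in_dtype)?;             // restore the caller's dtype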