summaryrefslogtreecommitdiff
path: root/candle-core/src/quantized/cuda.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-core/src/quantized/cuda.rs')
-rw-r--r--candle-core/src/quantized/cuda.rs88
1 files changed, 75 insertions, 13 deletions
diff --git a/candle-core/src/quantized/cuda.rs b/candle-core/src/quantized/cuda.rs
index 5481ca3c..8e4884b2 100644
--- a/candle-core/src/quantized/cuda.rs
+++ b/candle-core/src/quantized/cuda.rs
@@ -2,6 +2,7 @@ use super::{GgmlDType, QStorage};
use crate::quantized::k_quants::GgmlType;
use crate::{backend::BackendDevice, cuda_backend::WrapErr};
use crate::{CudaDevice, CudaStorage, Result};
+use half::f16;
use cudarc::driver::{CudaSlice, CudaView, DeviceSlice};
@@ -59,7 +60,7 @@ fn quantize_q8_1(
Ok(())
}
-fn dequantize(
+fn dequantize_f32(
data: &CudaSlice<u8>,
dtype: GgmlDType,
elem_count: usize,
@@ -69,27 +70,27 @@ fn dequantize(
let nb = (elem_count + 255) / 256;
let (kernel_name, is_k, block_dim, num_blocks) = match dtype {
- GgmlDType::Q4_0 => ("dequantize_block_q4_0", false, 32, nb),
- GgmlDType::Q4_1 => ("dequantize_block_q4_1", false, 32, nb),
+ GgmlDType::Q4_0 => ("dequantize_block_q4_0_f32", false, 32, nb),
+ GgmlDType::Q4_1 => ("dequantize_block_q4_1_f32", false, 32, nb),
GgmlDType::Q5_0 => (
- "dequantize_block_q5_0",
+ "dequantize_block_q5_0_f32",
false,
CUDA_DEQUANTIZE_BLOCK_SIZE,
ceil_div(elem_count, 2 * CUDA_DEQUANTIZE_BLOCK_SIZE),
),
GgmlDType::Q5_1 => (
- "dequantize_block_q5_1",
+ "dequantize_block_q5_1_f32",
false,
CUDA_DEQUANTIZE_BLOCK_SIZE,
ceil_div(elem_count, 2 * CUDA_DEQUANTIZE_BLOCK_SIZE),
),
- GgmlDType::Q8_0 => ("dequantize_block_q8_0", false, 32, nb),
- GgmlDType::Q2K => ("dequantize_block_q2_K", true, 64, nb),
- GgmlDType::Q3K => ("dequantize_block_q3_K", true, 64, nb),
- GgmlDType::Q4K => ("dequantize_block_q4_K", true, 32, nb),
- GgmlDType::Q5K => ("dequantize_block_q5_K", true, 64, nb),
- GgmlDType::Q6K => ("dequantize_block_q6_K", true, 64, nb),
- GgmlDType::Q8K => ("dequantize_block_q8_K", true, 32, nb),
+ GgmlDType::Q8_0 => ("dequantize_block_q8_0_f32", false, 32, nb),
+ GgmlDType::Q2K => ("dequantize_block_q2_K_f32", true, 64, nb),
+ GgmlDType::Q3K => ("dequantize_block_q3_K_f32", true, 64, nb),
+ GgmlDType::Q4K => ("dequantize_block_q4_K_f32", true, 32, nb),
+ GgmlDType::Q5K => ("dequantize_block_q5_K_f32", true, 64, nb),
+ GgmlDType::Q6K => ("dequantize_block_q6_K_f32", true, 64, nb),
+ GgmlDType::Q8K => ("dequantize_block_q8_K_f32", true, 32, nb),
_ => crate::bail!("unsupported dtype for dequantize {dtype:?}"),
};
let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
@@ -116,6 +117,63 @@ fn dequantize(
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
}
+fn dequantize_f16(
+ data: &CudaSlice<u8>,
+ dtype: GgmlDType,
+ elem_count: usize,
+ dev: &CudaDevice,
+) -> Result<CudaStorage> {
+ use cudarc::driver::LaunchAsync;
+
+ let nb = (elem_count + 255) / 256;
+ let (kernel_name, is_k, block_dim, num_blocks) = match dtype {
+ GgmlDType::Q4_0 => ("dequantize_block_q4_0_f16", false, 32, nb),
+ GgmlDType::Q4_1 => ("dequantize_block_q4_1_f16", false, 32, nb),
+ GgmlDType::Q5_0 => (
+ "dequantize_block_q5_0_f16",
+ false,
+ CUDA_DEQUANTIZE_BLOCK_SIZE,
+ ceil_div(elem_count, 2 * CUDA_DEQUANTIZE_BLOCK_SIZE),
+ ),
+ GgmlDType::Q5_1 => (
+ "dequantize_block_q5_1_f16",
+ false,
+ CUDA_DEQUANTIZE_BLOCK_SIZE,
+ ceil_div(elem_count, 2 * CUDA_DEQUANTIZE_BLOCK_SIZE),
+ ),
+ GgmlDType::Q8_0 => ("dequantize_block_q8_0_f16", false, 32, nb),
+ GgmlDType::Q2K => ("dequantize_block_q2_K_f16", true, 64, nb),
+ GgmlDType::Q3K => ("dequantize_block_q3_K_f16", true, 64, nb),
+ GgmlDType::Q4K => ("dequantize_block_q4_K_f16", true, 32, nb),
+ GgmlDType::Q5K => ("dequantize_block_q5_K_f16", true, 64, nb),
+ GgmlDType::Q6K => ("dequantize_block_q6_K_f16", true, 64, nb),
+ GgmlDType::Q8K => ("dequantize_block_q8_K_f16", true, 32, nb),
+ _ => crate::bail!("unsupported dtype for dequantize {dtype:?}"),
+ };
+ let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
+ let dst = unsafe { dev.alloc::<f16>(elem_count).w()? };
+ // See e.g.
+ // https://github.com/ggerganov/llama.cpp/blob/cbbd1efa06f8c09f9dff58ff9d9af509cc4c152b/ggml-cuda.cu#L7270
+ let cfg = cudarc::driver::LaunchConfig {
+ grid_dim: (num_blocks as u32, 1, 1),
+ block_dim: (block_dim as u32, 1, 1),
+ shared_mem_bytes: 0,
+ };
+
+ if is_k {
+ let params = (data, &dst);
+ unsafe { func.launch(cfg, params) }.w()?;
+ } else {
+ let nb32 = match dtype {
+ GgmlDType::Q5_0 | GgmlDType::Q5_1 => elem_count,
+ _ => elem_count / 32,
+ };
+ let params = (data, &dst, nb32 as i32);
+ unsafe { func.launch(cfg, params) }.w()?;
+ }
+ Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
+}
+
fn dequantize_mul_mat_vec(
data: &CudaSlice<u8>,
y: &CudaView<f32>,
@@ -341,7 +399,7 @@ impl QCudaStorage {
| GgmlDType::Q8K
);
if fast_kernel {
- return dequantize(&self.data, self.dtype, elem_count, self.device());
+ return dequantize_f32(&self.data, self.dtype, elem_count, self.device());
}
// Run the dequantization on cpu.
@@ -369,6 +427,10 @@ impl QCudaStorage {
.storage_from_cpu_storage(&crate::CpuStorage::F32(out))
}
+ pub fn dequantize_f16(&self, elem_count: usize) -> Result<CudaStorage> {
+ dequantize_f16(&self.data, self.dtype, elem_count, self.device())
+ }
+
pub fn quantize(&mut self, src: &CudaStorage) -> Result<()> {
// Run the quantization on cpu.
let src = match &src.slice {