diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-02-25 18:11:47 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-02-25 18:11:47 +0100 |
commit | 2f22afd80ef6bc3e0ac7f6d55e4a4dc4dd480190 (patch) | |
tree | a0fca7887e011d5c8fc75c10c6fb2fd7d90d56cb /candle-core | |
parent | 8d04f70f4d1bd67c42fb7d63e7031d49cf780a61 (diff) | |
download | candle-2f22afd80ef6bc3e0ac7f6d55e4a4dc4dd480190.tar.gz candle-2f22afd80ef6bc3e0ac7f6d55e4a4dc4dd480190.tar.bz2 candle-2f22afd80ef6bc3e0ac7f6d55e4a4dc4dd480190.zip |
Cuda acceleration for quantized model. (#1754)
* Boilerplate for the quantized cuda support.
* More basic cuda support.
* More cuda quantization (quantize on cpu for now).
* Add the dequantization bit.
* Start adding some dedicated cuda kernels from llama.cpp.
* Move the kernel code.
* Start interfacing with the kernel.
* Tweak the kernel launch params.
* Bugfix for quantized metal.
* Fix some clippy lints.
* Tweak the launch parameters.
* Tweak cuda basics to perform a quantized matmul.
* Perform the dequantization on the cpu + use cublas for matmul.
* Add the dequantization kernel.
* Test the qmatmul.
* More kernels.
* Matmul-vec kernel.
* Add a couple kernels.
* More dequantization kernels.
Diffstat (limited to 'candle-core')
-rw-r--r-- | candle-core/examples/cuda_basics.rs | 39 | ||||
-rw-r--r-- | candle-core/src/metal_backend.rs | 10 | ||||
-rw-r--r-- | candle-core/src/quantized/cuda.rs | 321 | ||||
-rw-r--r-- | candle-core/src/quantized/dummy_cuda.rs | 50 | ||||
-rw-r--r-- | candle-core/src/quantized/dummy_metal.rs | 7 | ||||
-rw-r--r-- | candle-core/src/quantized/ggml_file.rs | 11 | ||||
-rw-r--r-- | candle-core/src/quantized/metal.rs | 53 | ||||
-rw-r--r-- | candle-core/src/quantized/mod.rs | 36 |
8 files changed, 458 insertions, 69 deletions
diff --git a/candle-core/examples/cuda_basics.rs b/candle-core/examples/cuda_basics.rs index ad207461..6e078a6e 100644 --- a/candle-core/examples/cuda_basics.rs +++ b/candle-core/examples/cuda_basics.rs @@ -5,25 +5,32 @@ extern crate accelerate_src; extern crate intel_mkl_src; use anyhow::Result; -use candle_core::{Device, Tensor}; +use candle_core::{Device, Module, Tensor}; + +use candle_core::quantized::{QMatMul, QTensor}; fn main() -> Result<()> { let device = Device::new_cuda(0)?; - let in_t = Tensor::rand(-1f32, 1f32, (1, 3, 12, 7), &device)?; - let k_t = Tensor::rand(-1f32, 1f32, (6, 3, 1, 1), &device)?; - let out_t = in_t.conv2d(&k_t, 0, 1, 1, 1)?; - println!("{out_t}"); - let in_t = in_t.to_device(&Device::Cpu)?; - let k_t = k_t.to_device(&Device::Cpu)?; - let out_t2 = in_t.conv2d(&k_t, 0, 1, 1, 1)?; - let diff = (out_t.to_device(&Device::Cpu)? - out_t2)? - .sqr()? - .sum_all()?; - println!("{diff}"); + let q = Tensor::randn(0f32, 1.0, (72, 32), &device)?; + let q_cpu = q.to_device(&Device::Cpu)?; + let q = QTensor::quantize(&q, candle_core::quantized::GgmlDType::Q4_0)?; + let q = QMatMul::from_qtensor(q)?; + let x = Tensor::randn(0f32, 1.0, (5, 32), &device)?; + let res_q_cuda = q.forward(&x)?; + println!("{res_q_cuda}"); - let t = Tensor::randn(0f32, 1f32, (2, 4, 96, 96), &device)?; - let w = Tensor::randn(0f32, 1f32, (320, 4, 3, 3), &device)?; - let res = t.conv2d(&w, 1, 1, 1, 1)?; - println!("{res:?}"); + let q_cpu = QTensor::quantize(&q_cpu, candle_core::quantized::GgmlDType::Q4_0)?; + let q_cpu_tensor = q_cpu.dequantize(&Device::Cpu)?; + let q_cpu = QMatMul::from_qtensor(q_cpu)?; + let x_cpu = x.to_device(&Device::Cpu)?; + let res_q_cpu = q_cpu.forward(&x_cpu)?; + println!("{res_q_cpu}"); + + let res_mm = x_cpu.matmul(&q_cpu_tensor.t()?)?; + let diff = (res_mm - res_q_cuda.to_device(&Device::Cpu))? + .abs()? + .flatten_all()? + .max(0)?; + println!("{diff}"); Ok(()) } diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index c19d7c56..959f0f31 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -827,9 +827,9 @@ impl BackendStorage for MetalStorage { layout.start_offset() * self.dtype.size_in_bytes(), ), &t.buffer, - (&t_l.stride(), t_l.start_offset() * t.dtype.size_in_bytes()), + (t_l.stride(), t_l.start_offset() * t.dtype.size_in_bytes()), &f.buffer, - (&f_l.stride(), f_l.start_offset() * f.dtype.size_in_bytes()), + (f_l.stride(), f_l.start_offset() * f.dtype.size_in_bytes()), &buffer, ) .map_err(MetalError::from)?; @@ -1264,7 +1264,7 @@ impl BackendStorage for MetalStorage { let src_offset = (src_l.start_offset() * self.dtype.size_in_bytes()) as NSUInteger; let length = (src_l.shape().elem_count() * self.dtype.size_in_bytes()) as NSUInteger; let dst_offset = (dst_offset * dst.dtype().size_in_bytes()) as NSUInteger; - blit.copy_from_buffer(&self.buffer, src_offset, &dst.buffer(), dst_offset, length); + blit.copy_from_buffer(&self.buffer, src_offset, dst.buffer(), dst_offset, length); blit.end_encoding(); } else { let src_shape = src_l.shape(); @@ -1636,7 +1636,7 @@ impl BackendDevice for MetalDevice { min as f32, max as f32, shape.elem_count(), - &*self.seed.lock().unwrap(), + &self.seed.lock().unwrap(), &buffer, ) .map_err(MetalError::from)?; @@ -1667,7 +1667,7 @@ impl BackendDevice for MetalDevice { mean as f32, stddev as f32, shape.elem_count(), - &*self.seed.lock().unwrap(), + &self.seed.lock().unwrap(), &buffer, ) .map_err(MetalError::from)?; diff --git a/candle-core/src/quantized/cuda.rs b/candle-core/src/quantized/cuda.rs new file mode 100644 index 00000000..a2fc6655 --- /dev/null +++ b/candle-core/src/quantized/cuda.rs @@ -0,0 +1,321 @@ +use super::{GgmlDType, QStorage}; +use crate::{backend::BackendDevice, cuda_backend::WrapErr}; +use crate::{CudaDevice, CudaStorage, Result}; + +use cudarc::driver::{CudaSlice, DeviceSlice}; + +pub struct QCudaStorage { + data: CudaSlice<u8>, + dtype: GgmlDType, + device: CudaDevice, +} + +pub const WARP_SIZE: usize = 32; +pub const MMQ_X_Q4_0_AMPERE: usize = 4; +pub const MMQ_Y_Q4_0_AMPERE: usize = 32; +pub const NWARPS_Q4_0_AMPERE: usize = 4; +pub const GGML_CUDA_MMV_X: usize = 32; +pub const GGML_CUDA_MMV_Y: usize = 1; + +fn dequantize( + data: &CudaSlice<u8>, + dtype: GgmlDType, + elem_count: usize, + dev: &CudaDevice, +) -> Result<CudaStorage> { + use cudarc::driver::LaunchAsync; + + let (kernel_name, is_k) = match dtype { + GgmlDType::Q4_0 => ("dequantize_block_q4_0", false), + GgmlDType::Q4_1 => ("dequantize_block_q4_1", false), + GgmlDType::Q5_0 => ("dequantize_block_q5_0", false), + GgmlDType::Q5_1 => ("dequantize_block_q5_1", false), + GgmlDType::Q8_0 => ("dequantize_block_q8_0", false), + GgmlDType::Q2K => ("dequantize_block_q2_K", true), + GgmlDType::Q3K => ("dequantize_block_q3_K", true), + GgmlDType::Q4K => ("dequantize_block_q4_K", true), + GgmlDType::Q5K => ("dequantize_block_q5_K", true), + GgmlDType::Q6K => ("dequantize_block_q6_K", true), + _ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"), + }; + let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?; + let dst = dev.alloc_zeros::<f32>(elem_count).w()?; + let nb = (elem_count + 255) / 256; + let cfg = cudarc::driver::LaunchConfig { + grid_dim: (nb as u32, 1, 1), + block_dim: (32, 1, 1), + shared_mem_bytes: 0, + }; + + if is_k { + let params = (data, &dst); + unsafe { func.launch(cfg, params) }.w()?; + } else { + let nb32 = elem_count / 32; + let params = (data, &dst, nb32 as i32); + unsafe { func.launch(cfg, params) }.w()?; + } + Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone())) +} + +fn dequantize_mut_mal_vec( + data: &CudaSlice<u8>, + y: &cudarc::driver::CudaView<f32>, + dtype: GgmlDType, + ncols: usize, + nrows: usize, + dev: &CudaDevice, +) -> Result<CudaStorage> { + use cudarc::driver::LaunchAsync; + + let kernel_name = match dtype { + GgmlDType::Q4_0 => "dequantize_mul_mat_vec_q4_0_cuda", + GgmlDType::Q4_1 => "dequantize_mul_mat_vec_q4_1_cuda", + GgmlDType::Q5_0 => "dequantize_mul_mat_vec_q5_0_cuda", + GgmlDType::Q5_1 => "dequantize_mul_mat_vec_q5_1_cuda", + GgmlDType::Q8_0 => "dequantize_mul_mat_vec_q8_0_cuda", + GgmlDType::Q2K => "dequantize_mul_mat_vec_q2_k", + GgmlDType::Q3K => "dequantize_mul_mat_vec_q3_k", + GgmlDType::Q4K => "dequantize_mul_mat_vec_q4_k", + GgmlDType::Q5K => "dequantize_mul_mat_vec_q5_k", + GgmlDType::Q6K => "dequantize_mul_mat_vec_q6_k", + _ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"), + }; + let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?; + let dst = dev.alloc_zeros::<f32>(nrows).w()?; + let block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; + let cfg = cudarc::driver::LaunchConfig { + grid_dim: (block_num_y as u32, 1, 1), + block_dim: (WARP_SIZE as u32, GGML_CUDA_MMV_Y as u32, 1), + shared_mem_bytes: 0, + }; + + let params = (data, y, &dst, ncols as i32, nrows as i32); + unsafe { func.launch(cfg, params) }.w()?; + Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone())) +} + +impl QCudaStorage { + pub fn zeros(device: &CudaDevice, el_count: usize, dtype: GgmlDType) -> Result<Self> { + let size_in_bytes = el_count * dtype.type_size() / dtype.block_size(); + let data = device.alloc_zeros::<u8>(size_in_bytes).w()?; + Ok(QCudaStorage { + data, + device: device.clone(), + dtype, + }) + } + + pub fn dtype(&self) -> GgmlDType { + self.dtype + } + + pub fn device(&self) -> &CudaDevice { + &self.device + } + + pub fn dequantize(&self, elem_count: usize) -> Result<CudaStorage> { + let fast_kernel = match self.dtype { + GgmlDType::Q4_0 + | GgmlDType::Q4_1 + | GgmlDType::Q5_0 + | GgmlDType::Q5_1 + | GgmlDType::Q8_0 + | GgmlDType::Q2K + | GgmlDType::Q3K + | GgmlDType::Q4K + | GgmlDType::Q5K + | GgmlDType::Q6K => true, + _ => false, + }; + if fast_kernel { + return dequantize(&self.data, self.dtype, elem_count, self.device()); + } + // Run the dequantization on cpu. + use crate::quantized::k_quants::GgmlType; + + let buffer = self.device.dtoh_sync_copy(&self.data).w()?; + let mut out = vec![0.0; elem_count]; + let block_len = elem_count / self.dtype.block_size(); + match self.dtype { + GgmlDType::F32 => { + let slice = + unsafe { std::slice::from_raw_parts(buffer.as_ptr() as *const f32, block_len) }; + out.copy_from_slice(slice) + } + GgmlDType::F16 => { + let vec: Vec<half::f16> = read_to_vec(&buffer, block_len); + half::f16::to_float(&vec, &mut out)?; + } + GgmlDType::Q4_0 => { + let vec: Vec<crate::quantized::BlockQ4_0> = read_to_vec(&buffer, block_len); + crate::quantized::BlockQ4_0::to_float(&vec, &mut out)?; + } + GgmlDType::Q4_1 => { + let vec: Vec<crate::quantized::BlockQ4_1> = read_to_vec(&buffer, block_len); + crate::quantized::BlockQ4_1::to_float(&vec, &mut out)?; + } + GgmlDType::Q5_0 => { + let vec: Vec<crate::quantized::BlockQ5_0> = read_to_vec(&buffer, block_len); + crate::quantized::BlockQ5_0::to_float(&vec, &mut out)?; + } + GgmlDType::Q5_1 => { + let vec: Vec<crate::quantized::BlockQ5_1> = read_to_vec(&buffer, block_len); + crate::quantized::BlockQ5_1::to_float(&vec, &mut out)?; + } + GgmlDType::Q8_0 => { + let vec: Vec<crate::quantized::BlockQ8_0> = read_to_vec(&buffer, block_len); + crate::quantized::BlockQ8_0::to_float(&vec, &mut out)?; + } + GgmlDType::Q8_1 => { + let vec: Vec<crate::quantized::BlockQ8_1> = read_to_vec(&buffer, block_len); + crate::quantized::BlockQ8_1::to_float(&vec, &mut out)?; + } + GgmlDType::Q2K => { + let vec: Vec<crate::quantized::BlockQ2K> = read_to_vec(&buffer, block_len); + crate::quantized::BlockQ2K::to_float(&vec, &mut out)?; + } + GgmlDType::Q3K => { + let vec: Vec<crate::quantized::BlockQ3K> = read_to_vec(&buffer, block_len); + crate::quantized::BlockQ3K::to_float(&vec, &mut out)?; + } + GgmlDType::Q4K => { + let vec: Vec<crate::quantized::BlockQ4K> = read_to_vec(&buffer, block_len); + crate::quantized::BlockQ4K::to_float(&vec, &mut out)?; + } + GgmlDType::Q5K => { + let vec: Vec<crate::quantized::BlockQ5K> = read_to_vec(&buffer, block_len); + crate::quantized::BlockQ5K::to_float(&vec, &mut out)?; + } + GgmlDType::Q6K => { + let vec: Vec<crate::quantized::BlockQ6K> = read_to_vec(&buffer, block_len); + crate::quantized::BlockQ6K::to_float(&vec, &mut out)?; + } + GgmlDType::Q8K => { + let vec: Vec<crate::quantized::BlockQ8K> = read_to_vec(&buffer, block_len); + crate::quantized::BlockQ8K::to_float(&vec, &mut out)?; + } + } + + self.device + .storage_from_cpu_storage(&crate::CpuStorage::F32(out)) + } + + pub fn quantize(&mut self, src: &CudaStorage) -> Result<()> { + // Run the quantization on cpu. + let src = match &src.slice { + crate::cuda_backend::CudaStorageSlice::F32(data) => { + self.device.dtoh_sync_copy(data).w()? + } + _ => crate::bail!("only f32 can be quantized"), + }; + let src_len = src.len(); + let src = crate::Storage::Cpu(crate::CpuStorage::F32(src)); + let mut qcpu_storage = crate::Device::Cpu.qzeros(src_len, self.dtype)?; + qcpu_storage.quantize(&src)?; + let data = qcpu_storage.data()?; + let data = self.device.htod_sync_copy(data.as_ref()).w()?; + self.data = data; + Ok(()) + } + + pub fn storage_size_in_bytes(&self) -> usize { + self.data.len() + } + + pub fn fwd( + &self, + self_shape: &crate::Shape, + storage: &CudaStorage, + layout: &crate::Layout, + ) -> Result<(CudaStorage, crate::Shape)> { + let dmmv = match layout.shape().dims() { + [1, 1, _] | [1, _] => true, + _ => false, + }; + if dmmv { + self.dequantize_matmul_vec(self_shape, storage, layout) + } else { + self.dequantize_matmul(self_shape, storage, layout) + } + } +} + +impl QCudaStorage { + fn dequantize_matmul_vec( + &self, + self_shape: &crate::Shape, + rhs: &CudaStorage, + rhs_l: &crate::Layout, + ) -> Result<(CudaStorage, crate::Shape)> { + let (nrows, ncols) = self_shape.dims2()?; + let rhs = rhs.as_cuda_slice::<f32>()?; + let rhs = match rhs_l.contiguous_offsets() { + Some((o1, o2)) => rhs.slice(o1..o2), + None => Err(crate::Error::RequiresContiguous { op: "dmmv" }.bt())?, + }; + let (with_batch, k) = match rhs_l.shape().dims() { + [1, 1, k] => (true, k), + [1, k] => (false, k), + _ => crate::bail!("unexpected rhs shape in dmmv {:?}", rhs_l.shape()), + }; + if ncols != *k { + crate::bail!("mismatch on matmul dim {self_shape:?} {:?}", rhs_l.shape()) + } + + let out = + dequantize_mut_mal_vec(&self.data, &rhs, self.dtype, ncols, nrows, self.device())?; + let out_shape = if with_batch { + vec![1, 1, nrows] + } else { + vec![1, nrows] + }; + Ok((out, out_shape.into())) + } + + fn dequantize_matmul( + &self, + self_shape: &crate::Shape, + storage: &CudaStorage, + layout: &crate::Layout, + ) -> Result<(CudaStorage, crate::Shape)> { + use crate::backend::BackendStorage; + let (n, k) = self_shape.dims2()?; + let (b, m, k2) = match layout.shape().dims() { + &[b, m, k2] => (b, m, k2), + &[m, k2] => (1, m, k2), + s => crate::bail!("unexpected shape for input {s:?}"), + }; + if k2 != k { + crate::bail!("mismatch on matmul dim {self_shape:?} {:?}", layout.shape()) + } + + let data_f32 = self.dequantize(n * k)?; + let rhs_l = crate::Layout::new((k, n).into(), vec![1, k], 0); + let out = storage.matmul(&data_f32, (b, m, n, k), layout, &rhs_l)?; + let mut out_shape = layout.shape().dims().to_vec(); + out_shape.pop(); + out_shape.push(n); + Ok((out, out_shape.into())) + } +} + +fn read_to_vec<T: Clone>(buffer: &[u8], n: usize) -> Vec<T> { + let slice = unsafe { std::slice::from_raw_parts(buffer.as_ptr() as *const T, n) }; + slice.to_vec() +} + +pub fn load_quantized<T: super::GgmlType + Send + Sync + 'static>( + device: &CudaDevice, + data: &[T], +) -> Result<super::QStorage> { + let data = unsafe { + std::slice::from_raw_parts(data.as_ptr() as *const u8, core::mem::size_of_val(data)) + }; + let data = device.htod_sync_copy(data).w()?; + Ok(QStorage::Cuda(QCudaStorage { + data, + device: device.clone(), + dtype: T::DTYPE, + })) +} diff --git a/candle-core/src/quantized/dummy_cuda.rs b/candle-core/src/quantized/dummy_cuda.rs new file mode 100644 index 00000000..598c5cd1 --- /dev/null +++ b/candle-core/src/quantized/dummy_cuda.rs @@ -0,0 +1,50 @@ +#![allow(unused)] +use super::GgmlDType; +use crate::{CudaDevice, CudaStorage, Error, Result}; + +pub struct QCudaStorage { + dtype: GgmlDType, + device: CudaDevice, +} + +impl QCudaStorage { + pub fn zeros(_: &CudaDevice, _: usize, _: GgmlDType) -> Result<Self> { + Err(Error::NotCompiledWithCudaSupport) + } + + pub fn dtype(&self) -> GgmlDType { + self.dtype + } + + pub fn device(&self) -> &CudaDevice { + &self.device + } + + pub fn dequantize(&self, _elem_count: usize) -> Result<CudaStorage> { + Err(Error::NotCompiledWithCudaSupport) + } + + pub fn quantize(&mut self, _src: &CudaStorage) -> Result<()> { + Err(Error::NotCompiledWithCudaSupport) + } + + pub fn storage_size_in_bytes(&self) -> usize { + 0 + } + + pub fn fwd( + &self, + _self_shape: &crate::Shape, + _storage: &CudaStorage, + _layout: &crate::Layout, + ) -> Result<(CudaStorage, crate::Shape)> { + Err(Error::NotCompiledWithCudaSupport) + } +} + +pub fn load_quantized<T: super::GgmlType + Send + Sync + 'static>( + _device: &CudaDevice, + _data: &[T], +) -> Result<super::QStorage> { + Err(Error::NotCompiledWithCudaSupport) +} diff --git a/candle-core/src/quantized/dummy_metal.rs b/candle-core/src/quantized/dummy_metal.rs index 96f91c50..520d0ed4 100644 --- a/candle-core/src/quantized/dummy_metal.rs +++ b/candle-core/src/quantized/dummy_metal.rs @@ -41,3 +41,10 @@ impl QMetalStorage { Err(Error::NotCompiledWithMetalSupport) } } + +pub fn load_quantized<T: super::GgmlType + Send + Sync + 'static>( + _device: &MetalDevice, + _data: &[T], +) -> Result<super::QStorage> { + Err(Error::NotCompiledWithMetalSupport) +} diff --git a/candle-core/src/quantized/ggml_file.rs b/candle-core/src/quantized/ggml_file.rs index e6f5791c..99200bbd 100644 --- a/candle-core/src/quantized/ggml_file.rs +++ b/candle-core/src/quantized/ggml_file.rs @@ -1,7 +1,5 @@ //! Support for the GGML file format. -#[cfg(feature = "metal")] -use super::metal::load_quantized_metal; use super::{k_quants, GgmlDType, QStorage}; use crate::{Device, Result}; use byteorder::{LittleEndian, ReadBytesExt}; @@ -130,13 +128,8 @@ fn from_raw_data<T: super::GgmlType + Send + Sync + 'static>( let data = unsafe { std::slice::from_raw_parts(raw_data_ptr as *const T, n_blocks) }; let data: QStorage = match device { Device::Cpu => QStorage::Cpu(Box::new(data.to_vec())), - #[cfg(feature = "metal")] - Device::Metal(metal) => load_quantized_metal(metal, data)?, - #[cfg(not(feature = "metal"))] - Device::Metal(_metal) => { - crate::bail!("Metal backend requires `metal` feature") - } - device => unimplemented!("Implement quantized tensor for device {device:?}"), + Device::Metal(metal) => super::metal::load_quantized(metal, data)?, + Device::Cuda(cuda) => super::cuda::load_quantized(cuda, data)?, }; super::QTensor::new(data, dims) } diff --git a/candle-core/src/quantized/metal.rs b/candle-core/src/quantized/metal.rs index 5cdfe6ab..af1cf369 100644 --- a/candle-core/src/quantized/metal.rs +++ b/candle-core/src/quantized/metal.rs @@ -34,6 +34,8 @@ impl QMetalStorage { } pub fn dequantize(&self, elem_count: usize) -> Result<MetalStorage> { + use crate::quantized::k_quants::GgmlType; + let buffer = self.device.new_buffer_managed(self.buffer.length())?; let command_buffer = self.device.command_buffer()?; command_buffer.set_label("to_cpu"); @@ -43,81 +45,62 @@ impl QMetalStorage { blit.end_encoding(); self.device.wait_until_completed()?; let mut out = vec![0.0; elem_count]; + let block_len = elem_count / self.dtype.block_size(); match self.dtype { GgmlDType::F32 => { - let vec: Vec<f32> = read_to_vec(&buffer, elem_count); - use crate::quantized::k_quants::GgmlType; + let vec: Vec<f32> = read_to_vec(&buffer, block_len); f32::to_float(&vec, &mut out)?; } GgmlDType::F16 => { - let vec: Vec<half::f16> = read_to_vec(&buffer, elem_count); - use crate::quantized::k_quants::GgmlType; + let vec: Vec<half::f16> = read_to_vec(&buffer, block_len); half::f16::to_float(&vec, &mut out)?; } GgmlDType::Q4_0 => { - let vec: Vec<crate::quantized::BlockQ4_0> = read_to_vec(&buffer, elem_count); - use crate::quantized::k_quants::GgmlType; + let vec: Vec<crate::quantized::BlockQ4_0> = read_to_vec(&buffer, block_len); crate::quantized::BlockQ4_0::to_float(&vec, &mut out)?; } GgmlDType::Q4_1 => { - let vec: Vec<crate::quantized::BlockQ4_1> = read_to_vec(&buffer, elem_count); - use crate::quantized::k_quants::GgmlType; + let vec: Vec<crate::quantized::BlockQ4_1> = read_to_vec(&buffer, block_len); crate::quantized::BlockQ4_1::to_float(&vec, &mut out)?; } GgmlDType::Q5_0 => { - let vec: Vec<crate::quantized::BlockQ5_0> = read_to_vec(&buffer, elem_count); - use crate::quantized::k_quants::GgmlType; + let vec: Vec<crate::quantized::BlockQ5_0> = read_to_vec(&buffer, block_len); crate::quantized::BlockQ5_0::to_float(&vec, &mut out)?; } GgmlDType::Q5_1 => { - let vec: Vec<crate::quantized::BlockQ5_1> = read_to_vec(&buffer, elem_count); - use crate::quantized::k_quants::GgmlType; + let vec: Vec<crate::quantized::BlockQ5_1> = read_to_vec(&buffer, block_len); crate::quantized::BlockQ5_1::to_float(&vec, &mut out)?; } GgmlDType::Q8_0 => { - let vec: Vec<crate::quantized::BlockQ8_0> = read_to_vec(&buffer, elem_count); - use crate::quantized::k_quants::GgmlType; + let vec: Vec<crate::quantized::BlockQ8_0> = read_to_vec(&buffer, block_len); crate::quantized::BlockQ8_0::to_float(&vec, &mut out)?; } GgmlDType::Q8_1 => { - let vec: Vec<crate::quantized::BlockQ8_1> = read_to_vec(&buffer, elem_count); - use crate::quantized::k_quants::GgmlType; + let vec: Vec<crate::quantized::BlockQ8_1> = read_to_vec(&buffer, block_len); crate::quantized::BlockQ8_1::to_float(&vec, &mut out)?; } GgmlDType::Q2K => { - let vec: Vec<crate::quantized::BlockQ2K> = - read_to_vec(&buffer, elem_count / self.dtype.block_size()); - use crate::quantized::k_quants::GgmlType; + let vec: Vec<crate::quantized::BlockQ2K> = read_to_vec(&buffer, block_len); crate::quantized::BlockQ2K::to_float(&vec, &mut out)?; } GgmlDType::Q3K => { - let vec: Vec<crate::quantized::BlockQ3K> = - read_to_vec(&buffer, elem_count / self.dtype.block_size()); - use crate::quantized::k_quants::GgmlType; + let vec: Vec<crate::quantized::BlockQ3K> = read_to_vec(&buffer, block_len); crate::quantized::BlockQ3K::to_float(&vec, &mut out)?; } GgmlDType::Q4K => { - let vec: Vec<crate::quantized::BlockQ4K> = - read_to_vec(&buffer, elem_count / self.dtype.block_size()); - use crate::quantized::k_quants::GgmlType; + let vec: Vec<crate::quantized::BlockQ4K> = read_to_vec(&buffer, block_len); crate::quantized::BlockQ4K::to_float(&vec, &mut out)?; } GgmlDType::Q5K => { - let vec: Vec<crate::quantized::BlockQ5K> = - read_to_vec(&buffer, elem_count / self.dtype.block_size()); - use crate::quantized::k_quants::GgmlType; + let vec: Vec<crate::quantized::BlockQ5K> = read_to_vec(&buffer, block_len); crate::quantized::BlockQ5K::to_float(&vec, &mut out)?; } GgmlDType::Q6K => { - let vec: Vec<crate::quantized::BlockQ6K> = - read_to_vec(&buffer, elem_count / self.dtype.block_size()); - use crate::quantized::k_quants::GgmlType; + let vec: Vec<crate::quantized::BlockQ6K> = read_to_vec(&buffer, block_len); crate::quantized::BlockQ6K::to_float(&vec, &mut out)?; } GgmlDType::Q8K => { - let vec: Vec<crate::quantized::BlockQ8K> = - read_to_vec(&buffer, elem_count / self.dtype.block_size()); - use crate::quantized::k_quants::GgmlType; + let vec: Vec<crate::quantized::BlockQ8K> = read_to_vec(&buffer, block_len); crate::quantized::BlockQ8K::to_float(&vec, &mut out)?; } } @@ -192,7 +175,7 @@ impl QMetalStorage { } } -pub fn load_quantized_metal<T: super::GgmlType + Send + Sync + 'static>( +pub fn load_quantized<T: super::GgmlType + Send + Sync + 'static>( device: &MetalDevice, data: &[T], ) -> Result<QStorage> { diff --git a/candle-core/src/quantized/mod.rs b/candle-core/src/quantized/mod.rs index d14b2dc2..f7abcd93 100644 --- a/candle-core/src/quantized/mod.rs +++ b/candle-core/src/quantized/mod.rs @@ -4,6 +4,7 @@ use std::borrow::Cow; #[cfg(target_feature = "avx")] pub mod avx; +mod dummy_cuda; mod dummy_metal; pub mod ggml_file; pub mod gguf_file; @@ -14,6 +15,13 @@ pub mod metal; mod metal { pub use super::dummy_metal::*; } +#[cfg(feature = "cuda")] +pub mod cuda; +#[cfg(not(feature = "cuda"))] +mod cuda { + pub use super::dummy_cuda::*; +} + #[cfg(target_feature = "neon")] pub mod neon; #[cfg(target_feature = "simd128")] @@ -39,8 +47,9 @@ impl Device { let storage = metal::QMetalStorage::zeros(metal, elem_count, dtype)?; Ok(QStorage::Metal(storage)) } - Device::Cuda(_cuda) => { - crate::bail!("Cuda ggml quantization not supported"); + Device::Cuda(cuda) => { + let storage = cuda::QCudaStorage::zeros(cuda, elem_count, dtype)?; + Ok(QStorage::Cuda(storage)) } } } @@ -49,6 +58,7 @@ impl Device { pub enum QStorage { Cpu(Box<dyn QuantizedType>), Metal(metal::QMetalStorage), + Cuda(cuda::QCudaStorage), } impl QStorage { @@ -56,6 +66,7 @@ impl QStorage { match self { QStorage::Cpu(storage) => storage.block_size(), QStorage::Metal(storage) => storage.dtype().block_size(), + QStorage::Cuda(storage) => storage.dtype().block_size(), } } @@ -63,6 +74,7 @@ impl QStorage { match self { QStorage::Cpu(storage) => storage.dtype(), QStorage::Metal(storage) => storage.dtype(), + QStorage::Cuda(storage) => storage.dtype(), } } @@ -70,6 +82,7 @@ impl QStorage { match self { QStorage::Cpu(_storage) => Device::Cpu, QStorage::Metal(storage) => Device::Metal(storage.device().clone()), + QStorage::Cuda(storage) => Device::Cuda(storage.device().clone()), } } @@ -77,6 +90,7 @@ impl QStorage { match self { QStorage::Cpu(storage) => storage.storage_size_in_bytes(), QStorage::Metal(storage) => storage.storage_size_in_bytes(), + QStorage::Cuda(storage) => storage.storage_size_in_bytes(), } } @@ -86,6 +100,7 @@ impl QStorage { storage.from_float(src.as_slice::<f32>()?)?; } (QStorage::Metal(storage), Storage::Metal(src)) => storage.quantize(src)?, + (QStorage::Cuda(storage), Storage::Cuda(src)) => storage.quantize(src)?, _ => crate::bail!("Invalid dequantize storage locations do not match"), } Ok(()) @@ -95,6 +110,7 @@ impl QStorage { match self { QStorage::Cpu(storage) => Ok(Storage::Cpu(storage.dequantize(elem_count)?)), QStorage::Metal(storage) => Ok(Storage::Metal(storage.dequantize(elem_count)?)), + QStorage::Cuda(storage) => Ok(Storage::Cuda(storage.dequantize(elem_count)?)), } } @@ -106,7 +122,7 @@ impl QStorage { let data = unsafe { std::slice::from_raw_parts(data_ptr, size_in_bytes) }; Ok(Cow::from(data)) } - QStorage::Metal(_storage) => { + QStorage::Metal(_) | QStorage::Cuda(_) => { crate::bail!("not implemented"); } } @@ -424,7 +440,7 @@ impl crate::CustomOp1 for QTensor { #[allow(clippy::infallible_destructuring_match)] let self_storage = match &self.storage { QStorage::Cpu(storage) => storage, - QStorage::Metal(_) => crate::bail!("Invalid storage"), + QStorage::Metal(_) | QStorage::Cuda(_) => crate::bail!("Invalid storage"), }; let slice = storage.as_slice::<f32>()?; let slice = &slice[layout.start_offset()..layout.start_offset() + src_shape.elem_count()]; @@ -444,6 +460,18 @@ impl crate::CustomOp1 for QTensor { }; self_storage.fwd(&self.shape, storage, layout) } + + fn cuda_fwd( + &self, + storage: &crate::CudaStorage, + layout: &crate::Layout, + ) -> Result<(crate::CudaStorage, Shape)> { + let self_storage = match &self.storage { + QStorage::Cuda(cuda) => cuda, + _ => unreachable!("Cannot call cuda matmul on non cuda QTensor"), + }; + self_storage.fwd(&self.shape, storage, layout) + } } impl crate::Module for QMatMul { |