author     Nicolas Patry <patry.nicolas@protonmail.com>    2024-01-17 10:27:58 +0100
committer  GitHub <noreply@github.com>                     2024-01-17 10:27:58 +0100
commit     403680f17ddc086295fbaee316cbed22d97a519b (patch)
tree       80dcffe6e929640e7f0ebfff3ba90410fd58992e /candle-core/src/quantized
parent     5270224f407502b82fe90bc2622894ce3871b002 (diff)
Quantized GGUF style (#1523)
* Metal quantized modifications proposal.
- Add a device param wherever needed.
- Create a new QMetalStorage type that implements QuantizedType.
- Update every call site that needs it.
Fix Python.
Fixing examples.
Fix: fmt + clippy + stub.
Moving everything around.
Only the actual implementations are missing.
Fixing everything + adding dequantized kernels.
More work.
Fixing matmul.
Fmt + Clippy
Some clippy fixes.
Working state.
Q2K Metal -> Bugged (also present in GGML).
Q4K CPU -> Bugged (present previously; a new test catches it).
Q5K CPU -> Bugged (present previously).
Q8_1 Both -> Never really implemented, it seems.
Q8K Metal -> Never implemented in Metal.
Fixing Q2K bug (present in ggml).
* Cleanup.
* Fix the rebase.
* Removing the fences speeds everything up and *is* correct this time...
* Cleanup the fence.
* After rebase.
* Bad code removal.
* Rebase after phi2 merge + fix replit default to CPU.
* Making the CI happy.
* More happy tests.
---------
Co-authored-by: Nicolas Patry <nicolas@Nicolass-MacBook-Pro.local>
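
For orientation, here is a minimal sketch of how the reworked API reads after this commit, based on the signatures in the diff below. The file path is a placeholder and the snippet is illustrative, not taken from the repository:

    use candle_core::quantized::{ggml_file, GgmlDType, QTensor};
    use candle_core::{Device, Result, Tensor};

    fn main() -> Result<()> {
        // Loading now threads the target device through, so tensor data
        // lands directly in CPU storage or in a Metal buffer.
        let device = Device::Cpu; // with the `metal` feature: Device::new_metal(0)?
        let mut file = std::fs::File::open("model.ggml")?; // placeholder path
        let _content = ggml_file::Content::read(&mut file, &device)?;

        // Quantization is now selected by a runtime GgmlDType value instead
        // of a compile-time block type parameter.
        let xs = Tensor::zeros((4, 32), candle_core::DType::F32, &device)?;
        let qt = QTensor::quantize(&xs, GgmlDType::Q4_0)?;
        let _roundtrip = qt.dequantize(&device)?;
        Ok(())
    }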
Diffstat (limited to 'candle-core/src/quantized')
-rw-r--r--  candle-core/src/quantized/ggml_file.rs |  84
-rw-r--r--  candle-core/src/quantized/gguf_file.rs |  28
-rw-r--r--  candle-core/src/quantized/metal.rs     | 153
-rw-r--r--  candle-core/src/quantized/mod.rs       | 302
4 files changed, 485 insertions, 82 deletions
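
One detail worth flagging before the per-file diffs: in the gguf_file.rs hunk below, every tensor's data is padded to a 32-byte boundary. A tiny standalone check of that padding arithmetic (the `padding` helper is hypothetical; only the formula comes from the diff):

    fn padding(size_in_bytes: usize) -> usize {
        // Same formula as in gguf_file.rs: pad up to the next 32-byte boundary.
        31 - (31 + size_in_bytes) % 32
    }

    fn main() {
        assert_eq!(padding(32), 0);  // already aligned: no padding
        assert_eq!(padding(33), 31); // 33 + 31 = 64, the next multiple of 32
        assert_eq!(padding(1), 31);  // 1 + 31 = 32
    }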
diff --git a/candle-core/src/quantized/ggml_file.rs b/candle-core/src/quantized/ggml_file.rs
index 1dd3d9c0..38238580 100644
--- a/candle-core/src/quantized/ggml_file.rs
+++ b/candle-core/src/quantized/ggml_file.rs
@@ -1,7 +1,9 @@
 //! Support for the GGML file format.
 
-use super::{k_quants, GgmlDType};
-use crate::Result;
+#[cfg(feature = "metal")]
+use super::metal::load_quantized_metal;
+use super::{k_quants, GgmlDType, QStorage};
+use crate::{Device, Result};
 use byteorder::{LittleEndian, ReadBytesExt};
 use std::collections::HashMap;
 
@@ -121,11 +123,22 @@ fn from_raw_data<T: super::GgmlType + Send + Sync + 'static>(
     raw_data: &[u8],
     size_in_bytes: usize,
     dims: Vec<usize>,
+    device: &Device,
 ) -> Result<super::QTensor> {
     let raw_data_ptr = raw_data.as_ptr();
     let n_blocks = size_in_bytes / std::mem::size_of::<T>();
     let data = unsafe { std::slice::from_raw_parts(raw_data_ptr as *const T, n_blocks) };
-    super::QTensor::new(data.to_vec(), dims)
+    let data: QStorage = match device {
+        Device::Cpu => QStorage::Cpu(Box::new(data.to_vec())),
+        #[cfg(feature = "metal")]
+        Device::Metal(metal) => load_quantized_metal(metal, data)?,
+        #[cfg(not(feature = "metal"))]
+        Device::Metal(_metal) => {
+            crate::bail!("Metal backend requires `metal` feature")
+        }
+        device => unimplemented!("Implement quantized tensor for device {device:?}"),
+    };
+    super::QTensor::new(data, dims)
 }
 
 /// Creates a [Tensor] from a raw GGML tensor.
@@ -133,29 +146,50 @@ pub fn qtensor_from_ggml(
     ggml_dtype: GgmlDType,
     raw_data: &[u8],
     dims: Vec<usize>,
+    device: &Device,
 ) -> Result<super::QTensor> {
     let tensor_elems = dims.iter().product::<usize>();
-    let blck_size = ggml_dtype.blck_size();
-    if tensor_elems % blck_size != 0 {
+    let block_size = ggml_dtype.block_size();
+    if tensor_elems % block_size != 0 {
         crate::bail!(
-            "the number of elements {tensor_elems} is not divisible by the block size {blck_size}"
+            "the number of elements {tensor_elems} is not divisible by the block size {block_size}"
         )
     }
-    let size_in_bytes = tensor_elems / blck_size * ggml_dtype.type_size();
+    let size_in_bytes = tensor_elems / block_size * ggml_dtype.type_size();
 
     match ggml_dtype {
-        GgmlDType::F32 => from_raw_data::<f32>(raw_data, size_in_bytes, dims),
-        GgmlDType::F16 => from_raw_data::<half::f16>(raw_data, size_in_bytes, dims),
-        GgmlDType::Q4_0 => from_raw_data::<k_quants::BlockQ4_0>(raw_data, size_in_bytes, dims),
-        GgmlDType::Q4_1 => from_raw_data::<k_quants::BlockQ4_1>(raw_data, size_in_bytes, dims),
-        GgmlDType::Q5_0 => from_raw_data::<k_quants::BlockQ5_0>(raw_data, size_in_bytes, dims),
-        GgmlDType::Q5_1 => from_raw_data::<k_quants::BlockQ5_1>(raw_data, size_in_bytes, dims),
-        GgmlDType::Q8_0 => from_raw_data::<k_quants::BlockQ8_0>(raw_data, size_in_bytes, dims),
-        GgmlDType::Q2K => from_raw_data::<k_quants::BlockQ2K>(raw_data, size_in_bytes, dims),
-        GgmlDType::Q3K => from_raw_data::<k_quants::BlockQ3K>(raw_data, size_in_bytes, dims),
-        GgmlDType::Q4K => from_raw_data::<k_quants::BlockQ4K>(raw_data, size_in_bytes, dims),
-        GgmlDType::Q5K => from_raw_data::<k_quants::BlockQ5K>(raw_data, size_in_bytes, dims),
-        GgmlDType::Q6K => from_raw_data::<k_quants::BlockQ6K>(raw_data, size_in_bytes, dims),
+        GgmlDType::F32 => from_raw_data::<f32>(raw_data, size_in_bytes, dims, device),
+        GgmlDType::F16 => from_raw_data::<half::f16>(raw_data, size_in_bytes, dims, device),
+        GgmlDType::Q4_0 => {
+            from_raw_data::<k_quants::BlockQ4_0>(raw_data, size_in_bytes, dims, device)
+        }
+        GgmlDType::Q4_1 => {
+            from_raw_data::<k_quants::BlockQ4_1>(raw_data, size_in_bytes, dims, device)
+        }
+        GgmlDType::Q5_0 => {
+            from_raw_data::<k_quants::BlockQ5_0>(raw_data, size_in_bytes, dims, device)
+        }
+        GgmlDType::Q5_1 => {
+            from_raw_data::<k_quants::BlockQ5_1>(raw_data, size_in_bytes, dims, device)
+        }
+        GgmlDType::Q8_0 => {
+            from_raw_data::<k_quants::BlockQ8_0>(raw_data, size_in_bytes, dims, device)
+        }
+        GgmlDType::Q2K => {
+            from_raw_data::<k_quants::BlockQ2K>(raw_data, size_in_bytes, dims, device)
+        }
+        GgmlDType::Q3K => {
+            from_raw_data::<k_quants::BlockQ3K>(raw_data, size_in_bytes, dims, device)
+        }
+        GgmlDType::Q4K => {
+            from_raw_data::<k_quants::BlockQ4K>(raw_data, size_in_bytes, dims, device)
+        }
+        GgmlDType::Q5K => {
+            from_raw_data::<k_quants::BlockQ5K>(raw_data, size_in_bytes, dims, device)
+        }
+        GgmlDType::Q6K => {
+            from_raw_data::<k_quants::BlockQ6K>(raw_data, size_in_bytes, dims, device)
+        }
         _ => crate::bail!("quantized type {ggml_dtype:?} is not supported yet"),
     }
 }
@@ -163,6 +197,7 @@ pub fn qtensor_from_ggml(
 fn read_one_tensor<R: std::io::Seek + std::io::Read>(
     reader: &mut R,
     magic: VersionedMagic,
+    device: &Device,
 ) -> Result<(String, super::QTensor)> {
     let n_dims = reader.read_u32::<LittleEndian>()?;
     let name_len = reader.read_u32::<LittleEndian>()?;
@@ -183,11 +218,11 @@ fn read_one_tensor<R: std::io::Seek + std::io::Read>(
     }
     let dims = dims.iter().map(|&u| u as usize).collect::<Vec<_>>();
     let tensor_elems = dims.iter().product::<usize>();
-    let size_in_bytes = tensor_elems * ggml_dtype.type_size() / ggml_dtype.blck_size();
+    let size_in_bytes = tensor_elems * ggml_dtype.type_size() / ggml_dtype.block_size();
     // TODO: Mmap version to avoid copying the data around?
     let mut raw_data = vec![0u8; size_in_bytes];
     reader.read_exact(&mut raw_data)?;
-    match qtensor_from_ggml(ggml_dtype, &raw_data, dims) {
+    match qtensor_from_ggml(ggml_dtype, &raw_data, dims, device) {
         Ok(tensor) => Ok((name, tensor)),
         Err(e) => crate::bail!("Error creating tensor {name}: {e}"),
     }
@@ -201,7 +236,10 @@ pub struct Content {
 }
 
 impl Content {
-    pub fn read<R: std::io::Seek + std::io::Read>(reader: &mut R) -> Result<Content> {
+    pub fn read<R: std::io::Seek + std::io::Read>(
+        reader: &mut R,
+        device: &Device,
+    ) -> Result<Content> {
         // https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/llama.cpp#L505
         let last_position = reader.seek(std::io::SeekFrom::End(0))?;
         reader.seek(std::io::SeekFrom::Start(0))?;
@@ -211,7 +249,7 @@ impl Content {
 
         let mut tensors = HashMap::new();
         while reader.stream_position()? != last_position {
-            let (name, tensor) = read_one_tensor(reader, magic)?;
+            let (name, tensor) = read_one_tensor(reader, magic, device)?;
             tensors.insert(name, tensor);
         }
         Ok(Self {
diff --git a/candle-core/src/quantized/gguf_file.rs b/candle-core/src/quantized/gguf_file.rs
index 587ffc0f..b729d4a0 100644
--- a/candle-core/src/quantized/gguf_file.rs
+++ b/candle-core/src/quantized/gguf_file.rs
@@ -3,7 +3,7 @@
 //! Spec: https://github.com/philpax/ggml/blob/gguf-spec/docs/gguf.md
 
 use super::{GgmlDType, QTensor};
-use crate::Result;
+use crate::{Device, Result};
 use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
 use std::collections::HashMap;
 
@@ -59,19 +59,25 @@ impl TensorInfo {
         &self,
         reader: &mut R,
         tensor_data_offset: u64,
+        device: &Device,
     ) -> Result<QTensor> {
         let tensor_elems = self.shape.elem_count();
-        let blck_size = self.ggml_dtype.blck_size();
-        if tensor_elems % blck_size != 0 {
+        let block_size = self.ggml_dtype.block_size();
+        if tensor_elems % block_size != 0 {
             crate::bail!(
-            "the number of elements {tensor_elems} is not divisible by the block size {blck_size}"
+            "the number of elements {tensor_elems} is not divisible by the block size {block_size}"
             )
         }
-        let size_in_bytes = tensor_elems / blck_size * self.ggml_dtype.type_size();
+        let size_in_bytes = tensor_elems / block_size * self.ggml_dtype.type_size();
         let mut raw_data = vec![0u8; size_in_bytes];
         reader.seek(std::io::SeekFrom::Start(tensor_data_offset + self.offset))?;
         reader.read_exact(&mut raw_data)?;
-        super::ggml_file::qtensor_from_ggml(self.ggml_dtype, &raw_data, self.shape.dims().to_vec())
+        super::ggml_file::qtensor_from_ggml(
+            self.ggml_dtype,
+            &raw_data,
+            self.shape.dims().to_vec(),
+            device,
+        )
     }
 }
 
@@ -460,12 +466,13 @@ impl Content {
         &self,
         reader: &mut R,
         name: &str,
+        device: &Device,
     ) -> Result<QTensor> {
         let tensor_info = match self.tensor_infos.get(name) {
             Some(tensor_info) => tensor_info,
             None => crate::bail!("cannot find tensor info for {name}"),
         };
-        tensor_info.read(reader, self.tensor_data_offset)
+        tensor_info.read(reader, self.tensor_data_offset, device)
     }
 }
 
@@ -517,10 +524,9 @@ pub fn write<W: std::io::Seek + std::io::Write>(
                 "internal error, unexpected current position {tensor_start_pos} {offset} {pos}"
             )
         }
-        let data_ptr = tensor.as_ptr();
-        let size_in_bytes = tensor.storage_size_in_bytes();
-        let data = unsafe { std::slice::from_raw_parts(data_ptr, size_in_bytes) };
-        w.write_all(data)?;
+        let data = tensor.data()?;
+        let size_in_bytes = data.len();
+        w.write_all(&data)?;
         let padding = 31 - (31 + size_in_bytes) % 32;
         w.write_all(&vec![0u8; padding])?;
     }
diff --git a/candle-core/src/quantized/metal.rs b/candle-core/src/quantized/metal.rs
new file mode 100644
index 00000000..fe57ce14
--- /dev/null
+++ b/candle-core/src/quantized/metal.rs
@@ -0,0 +1,153 @@
+use super::{GgmlDType, QStorage};
+use crate::{DType, MetalDevice, MetalStorage, Result};
+use metal::Buffer;
+use std::sync::Arc;
+
+pub struct QMetalStorage {
+    dtype: GgmlDType,
+    device: MetalDevice,
+    buffer: Arc<Buffer>,
+}
+
+impl QMetalStorage {
+    pub fn dtype(&self) -> GgmlDType {
+        self.dtype
+    }
+
+    pub fn buffer(&self) -> &Buffer {
+        &self.buffer
+    }
+
+    pub fn new(buffer: Arc<Buffer>, device: MetalDevice, dtype: GgmlDType) -> Self {
+        Self {
+            device,
+            buffer,
+            dtype,
+        }
+    }
+
+    pub fn dequantize(&self, elem_count: usize) -> Result<MetalStorage> {
+        let buffer = self.device.new_buffer_managed(self.buffer.length())?;
+        let command_buffer = self.device.command_buffer()?;
+        command_buffer.set_label("to_cpu");
+        let blit = command_buffer.new_blit_command_encoder();
+        blit.set_label("blit_to_cpu");
+        blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, self.buffer.length());
+        blit.end_encoding();
+        self.device.wait_until_completed()?;
+        let mut out = vec![0.0; elem_count];
+        match self.dtype {
+            GgmlDType::F32 => {
+                let vec: Vec<f32> = read_to_vec(&buffer, elem_count);
+                use crate::quantized::k_quants::GgmlType;
+                f32::to_float(&vec, &mut out)?;
+            }
+            GgmlDType::F16 => {
+                let vec: Vec<half::f16> = read_to_vec(&buffer, elem_count);
+                use crate::quantized::k_quants::GgmlType;
+                half::f16::to_float(&vec, &mut out)?;
+            }
+            GgmlDType::Q4_0 => {
+                let vec: Vec<crate::quantized::BlockQ4_0> = read_to_vec(&buffer, elem_count);
+                use crate::quantized::k_quants::GgmlType;
+                crate::quantized::BlockQ4_0::to_float(&vec, &mut out)?;
+            }
+            GgmlDType::Q4_1 => {
+                let vec: Vec<crate::quantized::BlockQ4_1> = read_to_vec(&buffer, elem_count);
+                use crate::quantized::k_quants::GgmlType;
+                crate::quantized::BlockQ4_1::to_float(&vec, &mut out)?;
+            }
+            GgmlDType::Q5_0 => {
+                let vec: Vec<crate::quantized::BlockQ5_0> = read_to_vec(&buffer, elem_count);
+                use crate::quantized::k_quants::GgmlType;
+                crate::quantized::BlockQ5_0::to_float(&vec, &mut out)?;
+            }
+            GgmlDType::Q5_1 => {
+                let vec: Vec<crate::quantized::BlockQ5_1> = read_to_vec(&buffer, elem_count);
+                use crate::quantized::k_quants::GgmlType;
+                crate::quantized::BlockQ5_1::to_float(&vec, &mut out)?;
+            }
+            GgmlDType::Q8_0 => {
+                let vec: Vec<crate::quantized::BlockQ8_0> = read_to_vec(&buffer, elem_count);
+                use crate::quantized::k_quants::GgmlType;
+                crate::quantized::BlockQ8_0::to_float(&vec, &mut out)?;
+            }
+            GgmlDType::Q8_1 => {
+                let vec: Vec<crate::quantized::BlockQ8_1> = read_to_vec(&buffer, elem_count);
+                use crate::quantized::k_quants::GgmlType;
+                crate::quantized::BlockQ8_1::to_float(&vec, &mut out)?;
+            }
+            GgmlDType::Q2K => {
+                let vec: Vec<crate::quantized::BlockQ2K> =
+                    read_to_vec(&buffer, elem_count / self.dtype.block_size());
+                use crate::quantized::k_quants::GgmlType;
+                crate::quantized::BlockQ2K::to_float(&vec, &mut out)?;
+            }
+            GgmlDType::Q3K => {
+                let vec: Vec<crate::quantized::BlockQ3K> =
+                    read_to_vec(&buffer, elem_count / self.dtype.block_size());
+                use crate::quantized::k_quants::GgmlType;
+                crate::quantized::BlockQ3K::to_float(&vec, &mut out)?;
+            }
+            GgmlDType::Q4K => {
+                let vec: Vec<crate::quantized::BlockQ4K> =
+                    read_to_vec(&buffer, elem_count / self.dtype.block_size());
+                use crate::quantized::k_quants::GgmlType;
+                crate::quantized::BlockQ4K::to_float(&vec, &mut out)?;
+            }
+            GgmlDType::Q5K => {
+                let vec: Vec<crate::quantized::BlockQ5K> =
+                    read_to_vec(&buffer, elem_count / self.dtype.block_size());
+                use crate::quantized::k_quants::GgmlType;
+                crate::quantized::BlockQ5K::to_float(&vec, &mut out)?;
+            }
+            GgmlDType::Q6K => {
+                let vec: Vec<crate::quantized::BlockQ6K> =
+                    read_to_vec(&buffer, elem_count / self.dtype.block_size());
+                use crate::quantized::k_quants::GgmlType;
+                crate::quantized::BlockQ6K::to_float(&vec, &mut out)?;
+            }
+            GgmlDType::Q8K => {
+                let vec: Vec<crate::quantized::BlockQ8K> =
+                    read_to_vec(&buffer, elem_count / self.dtype.block_size());
+                use crate::quantized::k_quants::GgmlType;
+                crate::quantized::BlockQ8K::to_float(&vec, &mut out)?;
+            }
+        }
+
+        let buffer = self.device.new_buffer_with_data(&out)?;
+        Ok(MetalStorage::new(buffer, self.device.clone(), DType::F32))
+    }
+
+    pub fn quantize(&mut self, src: &MetalStorage) -> Result<()> {
+        // Quantization only happens on CPU for now.
+        let src = src.to_cpu::<f32>()?;
+        let elem_count = src.len();
+        let src = crate::Storage::Cpu(crate::CpuStorage::F32(src));
+        let mut qcpu_storage = crate::Device::Cpu.qzeros(elem_count, self.dtype)?;
+        qcpu_storage.quantize(&src)?;
+        let buffer = self.device.new_buffer_with_data(&qcpu_storage.data()?)?;
+        self.buffer = buffer;
+        Ok(())
+    }
+}
+
+pub fn load_quantized_metal<T: super::GgmlType + Send + Sync + 'static>(
+    device: &MetalDevice,
+    data: &[T],
+) -> Result<QStorage> {
+    let buffer = device.new_buffer_with_data(data)?;
+    let device = device.clone();
+    Ok(QStorage::Metal(QMetalStorage {
+        dtype: T::DTYPE,
+        device,
+        buffer,
+    }))
+}
+
+fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> {
+    let ptr = buffer.contents() as *const T;
+    assert!(!ptr.is_null());
+    let slice = unsafe { std::slice::from_raw_parts(ptr, n) };
+    slice.to_vec()
+}
diff --git a/candle-core/src/quantized/mod.rs b/candle-core/src/quantized/mod.rs
index 043733ae..1dc5fe8f 100644
--- a/candle-core/src/quantized/mod.rs
+++ b/candle-core/src/quantized/mod.rs
@@ -1,23 +1,125 @@
-use crate::{Device, Result, Shape, Tensor};
+#[cfg(feature = "metal")]
+use crate::{backend::BackendStorage, DType};
+use crate::{CpuStorage, Device, Result, Shape, Storage, Tensor};
+use k_quants::*;
+use std::borrow::Cow;
 
 #[cfg(target_feature = "avx")]
 pub mod avx;
 pub mod ggml_file;
 pub mod gguf_file;
 pub mod k_quants;
+#[cfg(feature = "metal")]
+pub mod metal;
 #[cfg(target_feature = "neon")]
 pub mod neon;
 #[cfg(target_feature = "simd128")]
 pub mod simd128;
 pub mod utils;
+use half::f16;
 
 pub use k_quants::GgmlType;
 
 pub struct QTensor {
-    data: Box<dyn QuantizedType>,
+    storage: QStorage,
     shape: Shape,
 }
 
+impl Device {
+    fn qzeros(&self, elem_count: usize, dtype: GgmlDType) -> Result<QStorage> {
+        match self {
+            Device::Cpu => {
+                let storage = dtype.cpu_zeros(elem_count);
+                Ok(QStorage::Cpu(storage))
+            }
+            #[cfg(feature = "metal")]
+            Device::Metal(metal) => {
+                let size = elem_count * dtype.type_size() / dtype.block_size();
+                let buffer = metal.allocate_zeros(size)?;
+                Ok(QStorage::Metal(metal::QMetalStorage::new(
+                    buffer,
+                    metal.clone(),
+                    dtype,
+                )))
+            }
+            #[cfg(not(feature = "metal"))]
+            Device::Metal(_metal) => {
+                crate::bail!("Metal feature not activated");
+            }
+            Device::Cuda(_cuda) => {
+                crate::bail!("Cuda ggml quantization not supported");
+            }
+        }
+    }
+}
+
+pub enum QStorage {
+    Cpu(Box<dyn QuantizedType>),
+    #[cfg(feature = "metal")]
+    Metal(metal::QMetalStorage),
+}
+
+impl QStorage {
+    fn block_size(&self) -> usize {
+        match self {
+            QStorage::Cpu(storage) => storage.block_size(),
+            #[cfg(feature = "metal")]
+            QStorage::Metal(storage) => storage.dtype().block_size(),
+        }
+    }
+
+    fn dtype(&self) -> GgmlDType {
+        match self {
+            QStorage::Cpu(storage) => storage.dtype(),
+            #[cfg(feature = "metal")]
+            QStorage::Metal(storage) => storage.dtype(),
+        }
+    }
+
+    fn size_in_bytes(&self) -> usize {
+        match self {
+            QStorage::Cpu(storage) => storage.storage_size_in_bytes(),
+            #[cfg(feature = "metal")]
+            QStorage::Metal(storage) => storage.buffer().length() as usize,
+        }
+    }
+
+    fn quantize(&mut self, src: &Storage) -> Result<()> {
+        match (self, src) {
+            (QStorage::Cpu(storage), Storage::Cpu(src)) => {
+                storage.from_float(src.as_slice::<f32>()?)?;
+            }
+            #[cfg(feature = "metal")]
+            (QStorage::Metal(storage), Storage::Metal(src)) => storage.quantize(src)?,
+            _ => crate::bail!("Invalid dequantize storage locations do not match"),
+        }
+        Ok(())
+    }
+
+    fn dequantize(&self, elem_count: usize) -> Result<Storage> {
+        match self {
+            QStorage::Cpu(storage) => Ok(Storage::Cpu(storage.dequantize(elem_count)?)),
+            #[cfg(feature = "metal")]
+            QStorage::Metal(storage) => Ok(Storage::Metal(storage.dequantize(elem_count)?)),
+        }
+    }
+
+    fn data(&self) -> Result<Cow<[u8]>> {
+        match self {
+            QStorage::Cpu(storage) => {
+                let data_ptr = storage.as_ptr();
+                let size_in_bytes = storage.storage_size_in_bytes();
+                let data = unsafe { std::slice::from_raw_parts(data_ptr, size_in_bytes) };
+                Ok(Cow::from(data))
+            }
+            #[cfg(feature = "metal")]
+            QStorage::Metal(_storage) => {
+                crate::bail!("not implemented");
+            }
+        }
+    }
+}
+
 #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
 pub enum GgmlDType {
     F32,
@@ -77,6 +179,25 @@ impl GgmlDType {
         }
     }
 
+    /// The block dtype
+    pub fn cpu_zeros(&self, elem_count: usize) -> Box<dyn QuantizedType> {
+        match self {
+            Self::F32 => Box::new(vec![f32::zeros(); elem_count]),
+            Self::F16 => Box::new(vec![f16::zeros(); elem_count]),
+            Self::Q4_0 => Box::new(vec![BlockQ4_0::zeros(); elem_count / BlockQ4_0::BLCK_SIZE]),
+            Self::Q4_1 => Box::new(vec![BlockQ4_1::zeros(); elem_count / BlockQ4_1::BLCK_SIZE]),
+            Self::Q5_0 => Box::new(vec![BlockQ5_0::zeros(); elem_count / BlockQ5_0::BLCK_SIZE]),
+            Self::Q5_1 => Box::new(vec![BlockQ5_1::zeros(); elem_count / BlockQ5_1::BLCK_SIZE]),
+            Self::Q8_0 => Box::new(vec![BlockQ8_0::zeros(); elem_count / BlockQ8_0::BLCK_SIZE]),
+            Self::Q8_1 => Box::new(vec![BlockQ8_1::zeros(); elem_count / BlockQ8_1::BLCK_SIZE]),
+            Self::Q2K => Box::new(vec![BlockQ2K::zeros(); elem_count / BlockQ2K::BLCK_SIZE]),
+            Self::Q3K => Box::new(vec![BlockQ3K::zeros(); elem_count / BlockQ3K::BLCK_SIZE]),
+            Self::Q4K => Box::new(vec![BlockQ4K::zeros(); elem_count / BlockQ4K::BLCK_SIZE]),
+            Self::Q5K => Box::new(vec![BlockQ5K::zeros(); elem_count / BlockQ5K::BLCK_SIZE]),
+            Self::Q6K => Box::new(vec![BlockQ6K::zeros(); elem_count / BlockQ6K::BLCK_SIZE]),
+            Self::Q8K => Box::new(vec![BlockQ8K::zeros(); elem_count / BlockQ8K::BLCK_SIZE]),
+        }
+    }
+
     /// The type size for blocks in bytes.
     pub fn type_size(&self) -> usize {
         use k_quants::*;
@@ -100,7 +221,7 @@ impl GgmlDType {
     }
 
     /// The block size, i.e. the number of elements stored in each block.
-    pub fn blck_size(&self) -> usize {
+    pub fn block_size(&self) -> usize {
         match self {
             Self::F32 => 1,
             Self::F16 => 1,
@@ -119,9 +240,13 @@ impl GgmlDType {
 pub trait QuantizedType: Send + Sync {
     fn dtype(&self) -> GgmlDType;
     fn matmul_t(&self, mkn: (usize, usize, usize), lhs: &[f32], dst: &mut [f32]) -> Result<()>;
-    fn to_float(&self, ys: &mut [f32]) -> Result<()>;
+    fn dequantize(&self, elem_count: usize) -> Result<CpuStorage>;
     fn storage_size_in_bytes(&self) -> usize;
     fn as_ptr(&self) -> *const u8;
+    fn block_size(&self) -> usize;
+    #[allow(clippy::wrong_self_convention)]
+    fn from_float(&mut self, xs: &[f32]) -> Result<()>;
+    fn size(&self) -> usize;
 }
 
 impl<T: k_quants::GgmlType + Send + Sync> QuantizedType for Vec<T> {
@@ -129,12 +254,26 @@ impl<T: k_quants::GgmlType + Send + Sync> QuantizedType for Vec<T> {
         k_quants::matmul(mkn, lhs, self.as_slice(), dst)
     }
 
+    fn size(&self) -> usize {
+        self.len() * core::mem::size_of::<T>()
+    }
+
+    fn from_float(&mut self, xs: &[f32]) -> Result<()> {
+        T::from_float(xs, self)
+    }
+
     fn dtype(&self) -> GgmlDType {
         T::DTYPE
     }
 
-    fn to_float(&self, ys: &mut [f32]) -> Result<()> {
-        T::to_float(self.as_slice(), ys)
+    fn block_size(&self) -> usize {
+        T::BLCK_SIZE
+    }
+
+    fn dequantize(&self, elem_count: usize) -> Result<CpuStorage> {
+        let mut ys = vec![0.0f32; elem_count];
+        T::to_float(self.as_slice(), &mut ys)?;
+        Ok(CpuStorage::F32(ys))
     }
 
     fn storage_size_in_bytes(&self) -> usize {
@@ -152,56 +291,49 @@ impl std::fmt::Debug for QTensor {
     }
 }
 
-fn check_shape<T: k_quants::GgmlType>(shape: &Shape) -> Result<()> {
+fn check_shape(shape: &Shape, block_size: usize) -> Result<()> {
     let dims = shape.dims();
     if dims.is_empty() {
         crate::bail!("scalar tensor cannot be quantized {shape:?}")
     }
-    if dims[dims.len() - 1] % T::BLCK_SIZE != 0 {
+    if dims[dims.len() - 1] % block_size != 0 {
         crate::bail!(
             "quantized tensor must have their last dim divisible by block size {shape:?} {}",
-            T::BLCK_SIZE
+            block_size
         )
     }
     Ok(())
 }
 
 impl QTensor {
-    pub fn new<S: Into<Shape>, T: k_quants::GgmlType + Send + Sync + 'static>(
-        data: Vec<T>,
-        shape: S,
-    ) -> Result<Self> {
+    pub fn new<S: Into<Shape>>(storage: QStorage, shape: S) -> Result<Self> {
         let shape = shape.into();
-        check_shape::<T>(&shape)?;
-        Ok(Self {
-            data: Box::new(data),
-            shape,
-        })
+        check_shape(&shape, storage.block_size())?;
+        Ok(Self { storage, shape })
     }
 
-    pub fn quantize<T: k_quants::GgmlType + Send + Sync + 'static>(src: &Tensor) -> Result<Self> {
+    pub fn quantize(src: &Tensor, dtype: GgmlDType) -> Result<Self> {
         let shape = src.shape();
-        check_shape::<T>(shape)?;
-        let src = src
-            .to_dtype(crate::DType::F32)?
-            .flatten_all()?
-            .to_vec1::<f32>()?;
-        if src.len() % T::BLCK_SIZE != 0 {
+        let block_size = dtype.block_size();
+        check_shape(shape, block_size)?;
+        let src = src.to_dtype(crate::DType::F32)?.flatten_all()?;
+        let elem_count = shape.elem_count();
+        if elem_count % block_size != 0 {
             crate::bail!(
                 "tensor size ({shape:?}) is not divisible by block size {}",
-                T::BLCK_SIZE
+                block_size
             )
         }
-        let mut data = vec![T::zeros(); src.len() / T::BLCK_SIZE];
-        T::from_float(&src, &mut data)?;
+        let mut storage = src.device().qzeros(elem_count, dtype)?;
+        storage.quantize(&src.storage())?;
         Ok(Self {
-            data: Box::new(data),
+            storage,
             shape: shape.clone(),
         })
     }
 
     pub fn dtype(&self) -> GgmlDType {
-        self.data.dtype()
+        self.storage.dtype()
     }
 
     pub fn rank(&self) -> usize {
@@ -213,21 +345,19 @@ impl QTensor {
     }
 
     pub fn dequantize(&self, device: &Device) -> Result<Tensor> {
-        let mut f32_data = vec![0f32; self.shape.elem_count()];
-        self.data.to_float(&mut f32_data)?;
-        Tensor::from_vec(f32_data, &self.shape, device)
-    }
-
-    pub fn matmul_t(&self, mkn: (usize, usize, usize), lhs: &[f32], dst: &mut [f32]) -> Result<()> {
-        self.data.matmul_t(mkn, lhs, dst)
+        let storage = self.storage.dequantize(self.shape.elem_count())?;
+        let none = crate::op::BackpropOp::none();
+        let is_variable = false;
+        crate::tensor::from_storage(storage, self.shape.clone(), none, is_variable)
+            .to_device(device)
     }
 
     pub fn storage_size_in_bytes(&self) -> usize {
-        self.data.storage_size_in_bytes()
+        self.storage.size_in_bytes()
     }
 
-    pub fn as_ptr(&self) -> *const u8 {
-        self.data.as_ptr()
+    pub fn data(&self) -> Result<Cow<'_, [u8]>> {
+        self.storage.data()
     }
 }
 
@@ -294,17 +424,93 @@ impl crate::CustomOp1 for QTensor {
         }
         dst_shape.push(n);
         let dst_shape = Shape::from(dst_shape);
-        let storage = storage.as_slice::<f32>()?;
-        let storage =
-            &storage[layout.start_offset()..layout.start_offset() + src_shape.elem_count()];
+        #[allow(clippy::infallible_destructuring_match)]
+        let self_storage = match &self.storage {
+            QStorage::Cpu(storage) => storage,
+            #[cfg(feature = "metal")]
+            _ => crate::bail!("Invalid storage"),
+        };
+        let slice = storage.as_slice::<f32>()?;
+        let slice = &slice[layout.start_offset()..layout.start_offset() + src_shape.elem_count()];
         let mut dst_storage = vec![0f32; dst_shape.elem_count()];
-        self.matmul_t(
-            (dst_shape.elem_count() / n, k, n),
-            storage,
-            &mut dst_storage,
-        )?;
+        self_storage.matmul_t((dst_shape.elem_count() / n, k, n), slice, &mut dst_storage)?;
         Ok((crate::CpuStorage::F32(dst_storage), dst_shape))
     }
+
+    #[cfg(feature = "metal")]
+    fn metal_fwd(
+        &self,
+        storage: &crate::MetalStorage,
+        layout: &crate::Layout,
+    ) -> Result<(crate::MetalStorage, Shape)> {
+        use crate::MetalError;
+
+        if !layout.is_contiguous() {
+            crate::bail!("input tensor is not contiguous {layout:?}")
+        }
+        let src_shape = layout.shape();
+        // self is transposed so n is first then k.
+        if src_shape.rank() < 2 {
+            crate::bail!("input tensor has only one dimension {layout:?}")
+        }
+        let (n, k) = self.shape.dims2()?;
+        let mut dst_shape = src_shape.dims().to_vec();
+
+        let (b, m) = match dst_shape.len() {
+            3 => (dst_shape[0], dst_shape[1]),
+            2 => (1, dst_shape[0]),
+            n => crate::bail!("Invalid rank {n} for quantized matmul metal"),
+        };
+        let last_k = dst_shape.pop().unwrap();
+        if last_k != k {
+            crate::bail!("input tensor {layout:?} incompatible with {:?}", self.shape)
+        }
+        dst_shape.push(n);
+        let dst_shape = Shape::from(dst_shape);
+        let device = storage.device().clone();
+        let dst = device.new_buffer(dst_shape.elem_count(), DType::F32, "qmatmul")?;
+        let (buffer, dtype) = match &self.storage {
+            QStorage::Metal(metal) => (metal.buffer(), metal.dtype()),
+            _ => unreachable!("Cannot call metal matmul on non metal QTensor"),
+        };
+        let command_buffer = device.command_buffer()?;
+        candle_metal_kernels::call_quantized_matmul_t(
+            device.device(),
+            &command_buffer,
+            device.kernels(),
+            dtype.into(),
+            (b, m, n, k),
+            storage.buffer(),
+            layout.start_offset() * storage.dtype().size_in_bytes(),
+            buffer,
+            &dst,
+        )
+        .map_err(MetalError::from)?;
+        let dst_storage = crate::MetalStorage::new(dst, device, DType::F32);
+        Ok((dst_storage, dst_shape))
+    }
+}
+
+#[cfg(feature = "metal")]
+impl From<GgmlDType> for candle_metal_kernels::GgmlDType {
+    fn from(value: GgmlDType) -> Self {
+        match value {
+            GgmlDType::Q4_0 => candle_metal_kernels::GgmlDType::Q4_0,
+            GgmlDType::Q4_1 => candle_metal_kernels::GgmlDType::Q4_1,
+            GgmlDType::Q5_0 => candle_metal_kernels::GgmlDType::Q5_0,
+            GgmlDType::Q5_1 => candle_metal_kernels::GgmlDType::Q5_1,
+            GgmlDType::Q8_0 => candle_metal_kernels::GgmlDType::Q8_0,
+            GgmlDType::Q8_1 => candle_metal_kernels::GgmlDType::Q8_1,
+            GgmlDType::Q2K => candle_metal_kernels::GgmlDType::Q2K,
+            GgmlDType::Q3K => candle_metal_kernels::GgmlDType::Q3K,
+            GgmlDType::Q4K => candle_metal_kernels::GgmlDType::Q4K,
+            GgmlDType::Q5K => candle_metal_kernels::GgmlDType::Q5K,
+            GgmlDType::Q6K => candle_metal_kernels::GgmlDType::Q6K,
+            GgmlDType::Q8K => candle_metal_kernels::GgmlDType::Q8K,
+            GgmlDType::F16 => candle_metal_kernels::GgmlDType::F16,
+            GgmlDType::F32 => candle_metal_kernels::GgmlDType::F32,
+        }
+    }
+}
+
 impl crate::Module for QMatMul {
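
On the GGUF read side, metadata parsing stays device-agnostic and the device is only supplied when a tensor is materialized. A minimal sketch under the same assumptions as above (hypothetical path and tensor name):

    use candle_core::quantized::{gguf_file, QTensor};
    use candle_core::{Device, Result};

    fn load_tensor(path: &str, name: &str, device: &Device) -> Result<QTensor> {
        let mut file = std::fs::File::open(path)?;
        // Parsing the header and tensor infos does not touch any device.
        let content = gguf_file::Content::read(&mut file)?;
        // The device is needed only when the tensor data is read and uploaded.
        content.tensor(&mut file, name, device)
    }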