summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNicolas Patry <patry.nicolas@protonmail.com>2024-01-17 10:27:58 +0100
committerGitHub <noreply@github.com>2024-01-17 10:27:58 +0100
commit403680f17ddc086295fbaee316cbed22d97a519b (patch)
tree80dcffe6e929640e7f0ebfff3ba90410fd58992e
parent5270224f407502b82fe90bc2622894ce3871b002 (diff)
downloadcandle-403680f17ddc086295fbaee316cbed22d97a519b.tar.gz
candle-403680f17ddc086295fbaee316cbed22d97a519b.tar.bz2
candle-403680f17ddc086295fbaee316cbed22d97a519b.zip
Quantized GGUF style (#1523)
* Metal quantized modifications proposal. - Add a device param, wherever needed. - Create new QMetal storage thing that implements QuantizedType. - Update everywhere needed. Fix Python. Fixing examples. Fix: fmt + clippy + stub. Moving everything around. Only missing the actual implems. Fixing everything + adding dequantized kernels. More work. Fixing matmul. Fmt + Clippy Some clippy fixes. Working state. Q2K Metal -> Bugged (also present in GGML). Q4K CPU -> Bugged (present previously, new test catch it). Q5K CPU -> Bugged (present previously). Q8_1 Both -> Never really implemented it seems Q8K metal -> Never implemented in metal Fixing Q2K bug (present in ggml). * Cleanup. * Fix the rebase. * Removing the fences speeds everything up and *is* correct this time... * Cleanup the fence. * After rebase. * Bad code removal. * Rebase after phi2 merge + fix replit default to CPU. * Making the CI happy. * More happy tests. --------- Co-authored-by: Nicolas Patry <nicolas@Nicolass-MacBook-Pro.local>
-rw-r--r--candle-core/examples/tensor-tools.rs122
-rw-r--r--candle-core/src/metal_backend.rs112
-rw-r--r--candle-core/src/quantized/ggml_file.rs84
-rw-r--r--candle-core/src/quantized/gguf_file.rs28
-rw-r--r--candle-core/src/quantized/metal.rs153
-rw-r--r--candle-core/src/quantized/mod.rs302
-rw-r--r--candle-core/tests/quantized_tests.rs573
-rw-r--r--candle-examples/examples/blip/main.rs4
-rw-r--r--candle-examples/examples/llama2-c/main.rs8
-rw-r--r--candle-examples/examples/mistral/main.rs7
-rw-r--r--candle-examples/examples/phi/main.rs16
-rw-r--r--candle-examples/examples/quantized-t5/main.rs3
-rw-r--r--candle-examples/examples/quantized/main.rs16
-rw-r--r--candle-examples/examples/replit-code/main.rs13
-rw-r--r--candle-examples/examples/stable-lm/main.rs5
-rw-r--r--candle-examples/examples/whisper/main.rs6
-rw-r--r--candle-metal-kernels/src/lib.rs228
-rw-r--r--candle-metal-kernels/src/quantized.metal5107
-rw-r--r--candle-metal-kernels/src/tests.rs33
-rw-r--r--candle-metal-kernels/src/unary.metal2
-rw-r--r--candle-nn/examples/cpu_benchmarks.rs5
-rw-r--r--candle-pyo3/py_src/candle/utils/__init__.pyi8
-rw-r--r--candle-pyo3/src/lib.rs51
-rw-r--r--candle-transformers/src/models/quantized_llama.rs41
-rw-r--r--candle-transformers/src/models/quantized_mixformer.rs4
-rw-r--r--candle-transformers/src/quantized_var_builder.rs12
-rw-r--r--candle-wasm-examples/blip/src/bin/m.rs2
-rw-r--r--candle-wasm-examples/phi/src/bin/m.rs6
-rw-r--r--candle-wasm-examples/t5/src/bin/m-quantized.rs9
-rw-r--r--candle-wasm-examples/whisper/src/worker.rs1
-rw-r--r--candle-wasm-tests/tests/quantized_tests.rs2
31 files changed, 6447 insertions, 516 deletions
diff --git a/candle-core/examples/tensor-tools.rs b/candle-core/examples/tensor-tools.rs
index 337021aa..eb6ceb1c 100644
--- a/candle-core/examples/tensor-tools.rs
+++ b/candle-core/examples/tensor-tools.rs
@@ -1,5 +1,5 @@
-use candle_core::quantized::{gguf_file, k_quants, QTensor};
-use candle_core::{Device, Result, Tensor};
+use candle_core::quantized::{gguf_file, GgmlDType, QTensor};
+use candle_core::{Device, Result};
use clap::{Parser, Subcommand, ValueEnum};
use rayon::prelude::*;
@@ -11,12 +11,7 @@ enum QuantizationMode {
}
impl QuantizationMode {
- fn quantize(
- &self,
- name: &str,
- tensor: QTensor,
- default: fn(&Tensor) -> Result<QTensor>,
- ) -> Result<QTensor> {
+ fn quantize(&self, name: &str, tensor: QTensor, dtype: GgmlDType) -> Result<QTensor> {
match self {
Self::Llama => {
// Same behavior as the llama.cpp quantization.
@@ -24,9 +19,9 @@ impl QuantizationMode {
if should_quantize {
let tensor = tensor.dequantize(&Device::Cpu)?;
if name == "output.weight" {
- QTensor::quantize::<k_quants::BlockQ6K>(&tensor)
+ QTensor::quantize(&tensor, GgmlDType::Q6K)
} else {
- default(&tensor)
+ QTensor::quantize(&tensor, dtype)
}
} else {
Ok(tensor)
@@ -60,6 +55,27 @@ enum Quantization {
F32,
}
+impl Quantization {
+ fn dtype(&self) -> GgmlDType {
+ match self {
+ Quantization::Q4_0 => GgmlDType::Q4_0,
+ Quantization::Q4_1 => GgmlDType::Q4_1,
+ Quantization::Q5_0 => GgmlDType::Q5_0,
+ Quantization::Q5_1 => GgmlDType::Q5_1,
+ Quantization::Q8_0 => GgmlDType::Q8_0,
+ Quantization::Q8_1 => GgmlDType::Q8_1,
+ Quantization::Q2k => GgmlDType::Q2K,
+ Quantization::Q3k => GgmlDType::Q3K,
+ Quantization::Q4k => GgmlDType::Q4K,
+ Quantization::Q5k => GgmlDType::Q5K,
+ Quantization::Q6k => GgmlDType::Q6K,
+ Quantization::Q8k => GgmlDType::Q8K,
+ Quantization::F16 => GgmlDType::F16,
+ Quantization::F32 => GgmlDType::F32,
+ }
+ }
+}
+
#[derive(ValueEnum, Debug, Clone)]
enum Format {
Safetensors,
@@ -134,7 +150,12 @@ struct Args {
command: Command,
}
-fn run_ls(file: &std::path::PathBuf, format: Option<Format>, verbose: bool) -> Result<()> {
+fn run_ls(
+ file: &std::path::PathBuf,
+ format: Option<Format>,
+ verbose: bool,
+ device: &Device,
+) -> Result<()> {
let format = match format {
Some(format) => format,
None => match Format::infer(file) {
@@ -200,7 +221,7 @@ fn run_ls(file: &std::path::PathBuf, format: Option<Format>, verbose: bool) -> R
}
Format::Ggml => {
let mut file = std::fs::File::open(file)?;
- let content = candle_core::quantized::ggml_file::Content::read(&mut file)?;
+ let content = candle_core::quantized::ggml_file::Content::read(&mut file, device)?;
let mut tensors = content.tensors.into_iter().collect::<Vec<_>>();
tensors.sort_by(|a, b| a.0.cmp(&b.0));
for (name, qtensor) in tensors.iter() {
@@ -241,37 +262,8 @@ fn run_quantize_safetensors(
}
println!("tensors: {}", tensors.len());
- let quantize_fn = match q {
- Quantization::Q4_0 => QTensor::quantize::<k_quants::BlockQ4_0>,
- Quantization::Q4_1 => QTensor::quantize::<k_quants::BlockQ4_1>,
- Quantization::Q5_0 => QTensor::quantize::<k_quants::BlockQ5_0>,
- Quantization::Q5_1 => QTensor::quantize::<k_quants::BlockQ5_1>,
- Quantization::Q8_0 => QTensor::quantize::<k_quants::BlockQ8_0>,
- Quantization::Q8_1 => QTensor::quantize::<k_quants::BlockQ8_1>,
- Quantization::Q2k => QTensor::quantize::<k_quants::BlockQ2K>,
- Quantization::Q3k => QTensor::quantize::<k_quants::BlockQ3K>,
- Quantization::Q4k => QTensor::quantize::<k_quants::BlockQ4K>,
- Quantization::Q5k => QTensor::quantize::<k_quants::BlockQ5K>,
- Quantization::Q6k => QTensor::quantize::<k_quants::BlockQ6K>,
- Quantization::Q8k => QTensor::quantize::<k_quants::BlockQ8K>,
- Quantization::F16 => QTensor::quantize::<half::f16>,
- Quantization::F32 => QTensor::quantize::<f32>,
- };
- let block_size = match q {
- Quantization::Q4_0 => k_quants::QK4_0,
- Quantization::Q4_1 => k_quants::QK4_1,
- Quantization::Q5_0 => k_quants::QK5_0,
- Quantization::Q5_1 => k_quants::QK5_1,
- Quantization::Q8_0 => k_quants::QK8_0,
- Quantization::Q8_1 => k_quants::QK8_1,
- Quantization::Q2k
- | Quantization::Q3k
- | Quantization::Q4k
- | Quantization::Q5k
- | Quantization::Q6k
- | Quantization::Q8k => k_quants::QK_K,
- Quantization::F16 | Quantization::F32 => 1,
- };
+ let dtype = q.dtype();
+ let block_size = dtype.block_size();
let qtensors = tensors
.into_par_iter()
@@ -279,9 +271,9 @@ fn run_quantize_safetensors(
let should_quantize = tensor.rank() == 2 && tensor.dim(1)? % block_size == 0;
println!(" quantizing {name} {tensor:?} {should_quantize}");
let tensor = if should_quantize {
- quantize_fn(&tensor)?
+ QTensor::quantize(&tensor, dtype)?
} else {
- QTensor::quantize::<f32>(&tensor)?
+ QTensor::quantize(&tensor, GgmlDType::F32)?
};
Ok((name, tensor))
})
@@ -294,13 +286,17 @@ fn run_quantize_safetensors(
Ok(())
}
-fn run_dequantize(in_file: std::path::PathBuf, out_file: std::path::PathBuf) -> Result<()> {
+fn run_dequantize(
+ in_file: std::path::PathBuf,
+ out_file: std::path::PathBuf,
+ device: &Device,
+) -> Result<()> {
let mut in_file = std::fs::File::open(in_file)?;
let content = gguf_file::Content::read(&mut in_file)?;
let mut tensors = std::collections::HashMap::new();
for (tensor_name, _) in content.tensor_infos.iter() {
- let tensor = content.tensor(&mut in_file, tensor_name)?;
- let tensor = tensor.dequantize(&Device::Cpu)?;
+ let tensor = content.tensor(&mut in_file, tensor_name, device)?;
+ let tensor = tensor.dequantize(device)?;
tensors.insert(tensor_name.to_string(), tensor);
}
candle_core::safetensors::save(&tensors, out_file)?;
@@ -312,6 +308,7 @@ fn run_quantize(
out_file: std::path::PathBuf,
q: Quantization,
qmode: QuantizationMode,
+ device: &Device,
) -> Result<()> {
if in_files.is_empty() {
candle_core::bail!("no specified input files")
@@ -337,31 +334,15 @@ fn run_quantize(
let content = gguf_file::Content::read(&mut in_)?;
println!("tensors: {}", content.tensor_infos.len());
- let quantize_fn = match q {
- Quantization::Q4_0 => QTensor::quantize::<k_quants::BlockQ4_0>,
- Quantization::Q4_1 => QTensor::quantize::<k_quants::BlockQ4_1>,
- Quantization::Q5_0 => QTensor::quantize::<k_quants::BlockQ5_0>,
- Quantization::Q5_1 => QTensor::quantize::<k_quants::BlockQ5_1>,
- Quantization::Q8_0 => QTensor::quantize::<k_quants::BlockQ8_0>,
- Quantization::Q8_1 => QTensor::quantize::<k_quants::BlockQ8_1>,
- Quantization::Q2k => QTensor::quantize::<k_quants::BlockQ2K>,
- Quantization::Q3k => QTensor::quantize::<k_quants::BlockQ3K>,
- Quantization::Q4k => QTensor::quantize::<k_quants::BlockQ4K>,
- Quantization::Q5k => QTensor::quantize::<k_quants::BlockQ5K>,
- Quantization::Q6k => QTensor::quantize::<k_quants::BlockQ6K>,
- Quantization::Q8k => QTensor::quantize::<k_quants::BlockQ8K>,
- Quantization::F16 => QTensor::quantize::<half::f16>,
- Quantization::F32 => QTensor::quantize::<f32>,
- };
-
+ let dtype = q.dtype();
let qtensors = content
.tensor_infos
.par_iter()
.map(|(name, _)| {
println!(" quantizing {name}");
let mut in_file = std::fs::File::open(&in_files[0])?;
- let tensor = content.tensor(&mut in_file, name)?;
- let tensor = qmode.quantize(name, tensor, quantize_fn)?;
+ let tensor = content.tensor(&mut in_file, name, device)?;
+ let tensor = qmode.quantize(name, tensor, dtype)?;
Ok((name, tensor))
})
.collect::<Result<Vec<_>>>()?;
@@ -381,6 +362,7 @@ fn run_quantize(
fn main() -> anyhow::Result<()> {
let args = Args::parse();
+ let device = Device::Cpu;
match args.command {
Command::Ls {
files,
@@ -392,7 +374,7 @@ fn main() -> anyhow::Result<()> {
if multiple_files {
println!("--- {file:?} ---");
}
- run_ls(file, format.clone(), verbose)?
+ run_ls(file, format.clone(), verbose, &device)?
}
}
Command::Quantize {
@@ -400,8 +382,8 @@ fn main() -> anyhow::Result<()> {
out_file,
quantization,
mode,
- } => run_quantize(&in_file, out_file, quantization, mode)?,
- Command::Dequantize { in_file, out_file } => run_dequantize(in_file, out_file)?,
+ } => run_quantize(&in_file, out_file, quantization, mode, &device)?,
+ Command::Dequantize { in_file, out_file } => run_dequantize(in_file, out_file, &device)?,
}
Ok(())
}
diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs
index 5269a899..dc790ac9 100644
--- a/candle-core/src/metal_backend.rs
+++ b/candle-core/src/metal_backend.rs
@@ -84,13 +84,8 @@ pub struct MetalDevice {
command_buffer_index: Arc<RwLock<usize>>,
/// The maximum amount of [compute command encoder](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc) per [command buffer](https://developer.apple.com/documentation/metal/mtlcommandbuffer?language=objc)
compute_per_buffer: usize,
- /// Every compute command encoder (and blit encoders) are defended with this Fence, forcing the
- /// execution order to be linear.
- /// It could be relaxed in some circumstances, by managing ourselves the dependencies in the
- /// compute graph.
- fence: metal::Fence,
/// Simple keeper struct to keep track of the already compiled kernels so we can reuse them.
- /// Heavily used by [`candle_metal_kernels`], both fences need to match
+ /// Heavily used by [`candle_metal_kernels`]
kernels: Arc<candle_metal_kernels::Kernels>,
/// Simple allocator struct.
/// The buffers are stored in size buckets since ML tends to use similar shapes over and over.
@@ -221,10 +216,8 @@ impl MetalDevice {
let command_buffer = self.command_buffer()?;
command_buffer.set_label("with_data");
let blit = command_buffer.new_blit_command_encoder();
- blit.wait_for_fence(&self.fence);
blit.set_label("with_data_blit");
blit.copy_from_buffer(&tmp, 0, &real, 0, tmp.length());
- blit.update_fence(&self.fence);
blit.end_encoding();
// This is necessary, for mmaped safetensors
@@ -238,6 +231,27 @@ impl MetalDevice {
Ok(real)
}
+ pub fn allocate_zeros(&self, size_in_bytes: usize) -> Result<Arc<Buffer>> {
+ let buffer = self.allocate_buffer(
+ size_in_bytes as NSUInteger,
+ MTLResourceOptions::StorageModePrivate,
+ "allocate_zeros",
+ )?;
+ let command_buffer = self.command_buffer()?;
+ command_buffer.set_label("zeros");
+ let blit = command_buffer.new_blit_command_encoder();
+ blit.fill_buffer(
+ &buffer,
+ metal::NSRange {
+ location: 0,
+ length: buffer.length(),
+ },
+ 0,
+ );
+ blit.end_encoding();
+ Ok(buffer)
+ }
+
/// The critical allocator algorithm
fn allocate_buffer(
&self,
@@ -308,35 +322,14 @@ impl BackendStorage for MetalStorage {
}
fn to_cpu_storage(&self) -> Result<CpuStorage> {
- let length = self.buffer.length() as usize;
- let size = self.dtype.size_in_bytes();
- if length % size != 0 {
- crate::bail!(
- "The Metal buffer length is not aligned with dtype {:?}",
- self.dtype
- );
- }
- let buffer = self.device.new_buffer_managed(self.buffer.length())?;
- {
- let command_buffer = self.device.command_buffer()?;
- command_buffer.set_label("to_cpu");
- let blit = command_buffer.new_blit_command_encoder();
- blit.set_label("blit_to_cpu");
- blit.wait_for_fence(&self.device.fence);
- blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, self.buffer.length());
- blit.update_fence(&self.device.fence);
- blit.end_encoding();
- }
- self.device.wait_until_completed()?;
-
match self.dtype {
- DType::U8 => Ok(CpuStorage::U8(read_to_vec(&buffer, length / size))),
- DType::U32 => Ok(CpuStorage::U32(read_to_vec(&buffer, length / size))),
- DType::I64 => Ok(CpuStorage::I64(read_to_vec(&buffer, length / size))),
- DType::F16 => Ok(CpuStorage::F16(read_to_vec(&buffer, length / size))),
- DType::BF16 => Ok(CpuStorage::BF16(read_to_vec(&buffer, length / size))),
- DType::F32 => Ok(CpuStorage::F32(read_to_vec(&buffer, length / size))),
- DType::F64 => Ok(CpuStorage::F64(read_to_vec(&buffer, length / size))),
+ DType::U8 => Ok(CpuStorage::U8(self.to_cpu()?)),
+ DType::U32 => Ok(CpuStorage::U32(self.to_cpu()?)),
+ DType::I64 => Ok(CpuStorage::I64(self.to_cpu()?)),
+ DType::F16 => Ok(CpuStorage::F16(self.to_cpu()?)),
+ DType::BF16 => Ok(CpuStorage::BF16(self.to_cpu()?)),
+ DType::F32 => Ok(CpuStorage::F32(self.to_cpu()?)),
+ DType::F64 => Ok(CpuStorage::F64(self.to_cpu()?)),
}
}
@@ -1264,7 +1257,7 @@ impl BackendStorage for MetalStorage {
let src_offset = (src_l.start_offset() * self.dtype.size_in_bytes()) as NSUInteger;
let length = (src_l.shape().elem_count() * self.dtype.size_in_bytes()) as NSUInteger;
let dst_offset = (dst_offset * dst.dtype().size_in_bytes()) as NSUInteger;
- blit.copy_from_buffer(&self.buffer, src_offset, dst.buffer(), dst_offset, length);
+ blit.copy_from_buffer(&self.buffer, src_offset, &dst.buffer(), dst_offset, length);
blit.end_encoding();
} else {
let src_shape = src_l.shape();
@@ -1521,6 +1514,28 @@ impl MetalStorage {
command_buffer.set_label("binary");
Ok(Self::new(buffer, device.clone(), dtype))
}
+
+ pub(crate) fn to_cpu<T: Clone>(&self) -> Result<Vec<T>> {
+ let length = self.buffer.length() as usize;
+ let size = self.dtype.size_in_bytes();
+ if length % size != 0 {
+ crate::bail!(
+ "The Metal buffer length is not aligned with dtype {:?}",
+ self.dtype
+ );
+ }
+ let buffer = self.device.new_buffer_managed(self.buffer.length())?;
+ {
+ let command_buffer = self.device.command_buffer()?;
+ command_buffer.set_label("to_cpu");
+ let blit = command_buffer.new_blit_command_encoder();
+ blit.set_label("blit_to_cpu");
+ blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, self.buffer.length());
+ blit.end_encoding();
+ }
+ self.device.wait_until_completed()?;
+ Ok(read_to_vec(&buffer, length / size))
+ }
}
impl BackendDevice for MetalDevice {
@@ -1533,16 +1548,14 @@ impl BackendDevice for MetalDevice {
command_buffer.enqueue();
let command_buffer = Arc::new(RwLock::new(command_buffer));
let command_buffer_index = Arc::new(RwLock::new(0));
- let fence = device.new_fence();
- let kernels = Arc::new(Kernels::new(fence.clone()));
+ let kernels = Arc::new(Kernels::new());
let buffers = Arc::new(RwLock::new(HashMap::new()));
let compute_per_buffer = match std::env::var("CANDLE_METAL_COMPUTE_PER_BUFFER") {
Ok(val) => val.parse()?,
- _ => 20,
+ _ => 10,
};
Ok(Self {
device,
- fence,
command_queue,
command_buffer,
command_buffer_index,
@@ -1567,21 +1580,8 @@ impl BackendDevice for MetalDevice {
}
fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<MetalStorage> {
- let buffer = self.new_buffer(shape.elem_count(), dtype, "zeros")?;
- let command_buffer = self.command_buffer()?;
- command_buffer.set_label("zeros");
- let blit = command_buffer.new_blit_command_encoder();
- blit.wait_for_fence(&self.fence);
- blit.fill_buffer(
- &buffer,
- metal::NSRange {
- location: 0,
- length: buffer.length(),
- },
- 0,
- );
- blit.update_fence(&self.fence);
- blit.end_encoding();
+ let size = shape.elem_count() * dtype.size_in_bytes();
+ let buffer = self.allocate_zeros(size)?;
Ok(MetalStorage::new(buffer, self.clone(), dtype))
}
diff --git a/candle-core/src/quantized/ggml_file.rs b/candle-core/src/quantized/ggml_file.rs
index 1dd3d9c0..38238580 100644
--- a/candle-core/src/quantized/ggml_file.rs
+++ b/candle-core/src/quantized/ggml_file.rs
@@ -1,7 +1,9 @@
//! Support for the GGML file format.
-use super::{k_quants, GgmlDType};
-use crate::Result;
+#[cfg(feature = "metal")]
+use super::metal::load_quantized_metal;
+use super::{k_quants, GgmlDType, QStorage};
+use crate::{Device, Result};
use byteorder::{LittleEndian, ReadBytesExt};
use std::collections::HashMap;
@@ -121,11 +123,22 @@ fn from_raw_data<T: super::GgmlType + Send + Sync + 'static>(
raw_data: &[u8],
size_in_bytes: usize,
dims: Vec<usize>,
+ device: &Device,
) -> Result<super::QTensor> {
let raw_data_ptr = raw_data.as_ptr();
let n_blocks = size_in_bytes / std::mem::size_of::<T>();
let data = unsafe { std::slice::from_raw_parts(raw_data_ptr as *const T, n_blocks) };
- super::QTensor::new(data.to_vec(), dims)
+ let data: QStorage = match device {
+ Device::Cpu => QStorage::Cpu(Box::new(data.to_vec())),
+ #[cfg(feature = "metal")]
+ Device::Metal(metal) => load_quantized_metal(metal, data)?,
+ #[cfg(not(feature = "metal"))]
+ Device::Metal(_metal) => {
+ crate::bail!("Metal backend requires `metal` feature")
+ }
+ device => unimplemented!("Implement quantized tensor for device {device:?}"),
+ };
+ super::QTensor::new(data, dims)
}
/// Creates a [Tensor] from a raw GGML tensor.
@@ -133,29 +146,50 @@ pub fn qtensor_from_ggml(
ggml_dtype: GgmlDType,
raw_data: &[u8],
dims: Vec<usize>,
+ device: &Device,
) -> Result<super::QTensor> {
let tensor_elems = dims.iter().product::<usize>();
- let blck_size = ggml_dtype.blck_size();
- if tensor_elems % blck_size != 0 {
+ let block_size = ggml_dtype.block_size();
+ if tensor_elems % block_size != 0 {
crate::bail!(
- "the number of elements {tensor_elems} is not divisible by the block size {blck_size}"
+ "the number of elements {tensor_elems} is not divisible by the block size {block_size}"
)
}
- let size_in_bytes = tensor_elems / blck_size * ggml_dtype.type_size();
+ let size_in_bytes = tensor_elems / block_size * ggml_dtype.type_size();
match ggml_dtype {
- GgmlDType::F32 => from_raw_data::<f32>(raw_data, size_in_bytes, dims),
- GgmlDType::F16 => from_raw_data::<half::f16>(raw_data, size_in_bytes, dims),
- GgmlDType::Q4_0 => from_raw_data::<k_quants::BlockQ4_0>(raw_data, size_in_bytes, dims),
- GgmlDType::Q4_1 => from_raw_data::<k_quants::BlockQ4_1>(raw_data, size_in_bytes, dims),
- GgmlDType::Q5_0 => from_raw_data::<k_quants::BlockQ5_0>(raw_data, size_in_bytes, dims),
- GgmlDType::Q5_1 => from_raw_data::<k_quants::BlockQ5_1>(raw_data, size_in_bytes, dims),
- GgmlDType::Q8_0 => from_raw_data::<k_quants::BlockQ8_0>(raw_data, size_in_bytes, dims),
- GgmlDType::Q2K => from_raw_data::<k_quants::BlockQ2K>(raw_data, size_in_bytes, dims),
- GgmlDType::Q3K => from_raw_data::<k_quants::BlockQ3K>(raw_data, size_in_bytes, dims),
- GgmlDType::Q4K => from_raw_data::<k_quants::BlockQ4K>(raw_data, size_in_bytes, dims),
- GgmlDType::Q5K => from_raw_data::<k_quants::BlockQ5K>(raw_data, size_in_bytes, dims),
- GgmlDType::Q6K => from_raw_data::<k_quants::BlockQ6K>(raw_data, size_in_bytes, dims),
+ GgmlDType::F32 => from_raw_data::<f32>(raw_data, size_in_bytes, dims, device),
+ GgmlDType::F16 => from_raw_data::<half::f16>(raw_data, size_in_bytes, dims, device),
+ GgmlDType::Q4_0 => {
+ from_raw_data::<k_quants::BlockQ4_0>(raw_data, size_in_bytes, dims, device)
+ }
+ GgmlDType::Q4_1 => {
+ from_raw_data::<k_quants::BlockQ4_1>(raw_data, size_in_bytes, dims, device)
+ }
+ GgmlDType::Q5_0 => {
+ from_raw_data::<k_quants::BlockQ5_0>(raw_data, size_in_bytes, dims, device)
+ }
+ GgmlDType::Q5_1 => {
+ from_raw_data::<k_quants::BlockQ5_1>(raw_data, size_in_bytes, dims, device)
+ }
+ GgmlDType::Q8_0 => {
+ from_raw_data::<k_quants::BlockQ8_0>(raw_data, size_in_bytes, dims, device)
+ }
+ GgmlDType::Q2K => {
+ from_raw_data::<k_quants::BlockQ2K>(raw_data, size_in_bytes, dims, device)
+ }
+ GgmlDType::Q3K => {
+ from_raw_data::<k_quants::BlockQ3K>(raw_data, size_in_bytes, dims, device)
+ }
+ GgmlDType::Q4K => {
+ from_raw_data::<k_quants::BlockQ4K>(raw_data, size_in_bytes, dims, device)
+ }
+ GgmlDType::Q5K => {
+ from_raw_data::<k_quants::BlockQ5K>(raw_data, size_in_bytes, dims, device)
+ }
+ GgmlDType::Q6K => {
+ from_raw_data::<k_quants::BlockQ6K>(raw_data, size_in_bytes, dims, device)
+ }
_ => crate::bail!("quantized type {ggml_dtype:?} is not supported yet"),
}
}
@@ -163,6 +197,7 @@ pub fn qtensor_from_ggml(
fn read_one_tensor<R: std::io::Seek + std::io::Read>(
reader: &mut R,
magic: VersionedMagic,
+ device: &Device,
) -> Result<(String, super::QTensor)> {
let n_dims = reader.read_u32::<LittleEndian>()?;
let name_len = reader.read_u32::<LittleEndian>()?;
@@ -183,11 +218,11 @@ fn read_one_tensor<R: std::io::Seek + std::io::Read>(
}
let dims = dims.iter().map(|&u| u as usize).collect::<Vec<_>>();
let tensor_elems = dims.iter().product::<usize>();
- let size_in_bytes = tensor_elems * ggml_dtype.type_size() / ggml_dtype.blck_size();
+ let size_in_bytes = tensor_elems * ggml_dtype.type_size() / ggml_dtype.block_size();
// TODO: Mmap version to avoid copying the data around?
let mut raw_data = vec![0u8; size_in_bytes];
reader.read_exact(&mut raw_data)?;
- match qtensor_from_ggml(ggml_dtype, &raw_data, dims) {
+ match qtensor_from_ggml(ggml_dtype, &raw_data, dims, device) {
Ok(tensor) => Ok((name, tensor)),
Err(e) => crate::bail!("Error creating tensor {name}: {e}"),
}
@@ -201,7 +236,10 @@ pub struct Content {
}
impl Content {
- pub fn read<R: std::io::Seek + std::io::Read>(reader: &mut R) -> Result<Content> {
+ pub fn read<R: std::io::Seek + std::io::Read>(
+ reader: &mut R,
+ device: &Device,
+ ) -> Result<Content> {
// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/llama.cpp#L505
let last_position = reader.seek(std::io::SeekFrom::End(0))?;
reader.seek(std::io::SeekFrom::Start(0))?;
@@ -211,7 +249,7 @@ impl Content {
let mut tensors = HashMap::new();
while reader.stream_position()? != last_position {
- let (name, tensor) = read_one_tensor(reader, magic)?;
+ let (name, tensor) = read_one_tensor(reader, magic, device)?;
tensors.insert(name, tensor);
}
Ok(Self {
diff --git a/candle-core/src/quantized/gguf_file.rs b/candle-core/src/quantized/gguf_file.rs
index 587ffc0f..b729d4a0 100644
--- a/candle-core/src/quantized/gguf_file.rs
+++ b/candle-core/src/quantized/gguf_file.rs
@@ -3,7 +3,7 @@
//! Spec: https://github.com/philpax/ggml/blob/gguf-spec/docs/gguf.md
use super::{GgmlDType, QTensor};
-use crate::Result;
+use crate::{Device, Result};
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use std::collections::HashMap;
@@ -59,19 +59,25 @@ impl TensorInfo {
&self,
reader: &mut R,
tensor_data_offset: u64,
+ device: &Device,
) -> Result<QTensor> {
let tensor_elems = self.shape.elem_count();
- let blck_size = self.ggml_dtype.blck_size();
- if tensor_elems % blck_size != 0 {
+ let block_size = self.ggml_dtype.block_size();
+ if tensor_elems % block_size != 0 {
crate::bail!(
- "the number of elements {tensor_elems} is not divisible by the block size {blck_size}"
+ "the number of elements {tensor_elems} is not divisible by the block size {block_size}"
)
}
- let size_in_bytes = tensor_elems / blck_size * self.ggml_dtype.type_size();
+ let size_in_bytes = tensor_elems / block_size * self.ggml_dtype.type_size();
let mut raw_data = vec![0u8; size_in_bytes];
reader.seek(std::io::SeekFrom::Start(tensor_data_offset + self.offset))?;
reader.read_exact(&mut raw_data)?;
- super::ggml_file::qtensor_from_ggml(self.ggml_dtype, &raw_data, self.shape.dims().to_vec())
+ super::ggml_file::qtensor_from_ggml(
+ self.ggml_dtype,
+ &raw_data,
+ self.shape.dims().to_vec(),
+ device,
+ )
}
}
@@ -460,12 +466,13 @@ impl Content {
&self,
reader: &mut R,
name: &str,
+ device: &Device,
) -> Result<QTensor> {
let tensor_info = match self.tensor_infos.get(name) {
Some(tensor_info) => tensor_info,
None => crate::bail!("cannot find tensor info for {name}"),
};
- tensor_info.read(reader, self.tensor_data_offset)
+ tensor_info.read(reader, self.tensor_data_offset, device)
}
}
@@ -517,10 +524,9 @@ pub fn write<W: std::io::Seek + std::io::Write>(
"internal error, unexpected current position {tensor_start_pos} {offset} {pos}"
)
}
- let data_ptr = tensor.as_ptr();
- let size_in_bytes = tensor.storage_size_in_bytes();
- let data = unsafe { std::slice::from_raw_parts(data_ptr, size_in_bytes) };
- w.write_all(data)?;
+ let data = tensor.data()?;
+ let size_in_bytes = data.len();
+ w.write_all(&data)?;
let padding = 31 - (31 + size_in_bytes) % 32;
w.write_all(&vec![0u8; padding])?;
}
diff --git a/candle-core/src/quantized/metal.rs b/candle-core/src/quantized/metal.rs
new file mode 100644
index 00000000..fe57ce14
--- /dev/null
+++ b/candle-core/src/quantized/metal.rs
@@ -0,0 +1,153 @@
+use super::{GgmlDType, QStorage};
+use crate::{DType, MetalDevice, MetalStorage, Result};
+use metal::Buffer;
+use std::sync::Arc;
+
+pub struct QMetalStorage {
+ dtype: GgmlDType,
+ device: MetalDevice,
+ buffer: Arc<Buffer>,
+}
+
+impl QMetalStorage {
+ pub fn dtype(&self) -> GgmlDType {
+ self.dtype
+ }
+
+ pub fn buffer(&self) -> &Buffer {
+ &self.buffer
+ }
+
+ pub fn new(buffer: Arc<Buffer>, device: MetalDevice, dtype: GgmlDType) -> Self {
+ Self {
+ device,
+ buffer,
+ dtype,
+ }
+ }
+
+ pub fn dequantize(&self, elem_count: usize) -> Result<MetalStorage> {
+ let buffer = self.device.new_buffer_managed(self.buffer.length())?;
+ let command_buffer = self.device.command_buffer()?;
+ command_buffer.set_label("to_cpu");
+ let blit = command_buffer.new_blit_command_encoder();
+ blit.set_label("blit_to_cpu");
+ blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, self.buffer.length());
+ blit.end_encoding();
+ self.device.wait_until_completed()?;
+ let mut out = vec![0.0; elem_count];
+ match self.dtype {
+ GgmlDType::F32 => {
+ let vec: Vec<f32> = read_to_vec(&buffer, elem_count);
+ use crate::quantized::k_quants::GgmlType;
+ f32::to_float(&vec, &mut out)?;
+ }
+ GgmlDType::F16 => {
+ let vec: Vec<half::f16> = read_to_vec(&buffer, elem_count);
+ use crate::quantized::k_quants::GgmlType;
+ half::f16::to_float(&vec, &mut out)?;
+ }
+ GgmlDType::Q4_0 => {
+ let vec: Vec<crate::quantized::BlockQ4_0> = read_to_vec(&buffer, elem_count);
+ use crate::quantized::k_quants::GgmlType;
+ crate::quantized::BlockQ4_0::to_float(&vec, &mut out)?;
+ }
+ GgmlDType::Q4_1 => {
+ let vec: Vec<crate::quantized::BlockQ4_1> = read_to_vec(&buffer, elem_count);
+ use crate::quantized::k_quants::GgmlType;
+ crate::quantized::BlockQ4_1::to_float(&vec, &mut out)?;
+ }
+ GgmlDType::Q5_0 => {
+ let vec: Vec<crate::quantized::BlockQ5_0> = read_to_vec(&buffer, elem_count);
+ use crate::quantized::k_quants::GgmlType;
+ crate::quantized::BlockQ5_0::to_float(&vec, &mut out)?;
+ }
+ GgmlDType::Q5_1 => {
+ let vec: Vec<crate::quantized::BlockQ5_1> = read_to_vec(&buffer, elem_count);
+ use crate::quantized::k_quants::GgmlType;
+ crate::quantized::BlockQ5_1::to_float(&vec, &mut out)?;
+ }
+ GgmlDType::Q8_0 => {
+ let vec: Vec<crate::quantized::BlockQ8_0> = read_to_vec(&buffer, elem_count);
+ use crate::quantized::k_quants::GgmlType;
+ crate::quantized::BlockQ8_0::to_float(&vec, &mut out)?;
+ }
+ GgmlDType::Q8_1 => {
+ let vec: Vec<crate::quantized::BlockQ8_1> = read_to_vec(&buffer, elem_count);
+ use crate::quantized::k_quants::GgmlType;
+ crate::quantized::BlockQ8_1::to_float(&vec, &mut out)?;
+ }
+ GgmlDType::Q2K => {
+ let vec: Vec<crate::quantized::BlockQ2K> =
+ read_to_vec(&buffer, elem_count / self.dtype.block_size());
+ use crate::quantized::k_quants::GgmlType;
+ crate::quantized::BlockQ2K::to_float(&vec, &mut out)?;
+ }
+ GgmlDType::Q3K => {
+ let vec: Vec<crate::quantized::BlockQ3K> =
+ read_to_vec(&buffer, elem_count / self.dtype.block_size());
+ use crate::quantized::k_quants::GgmlType;
+ crate::quantized::BlockQ3K::to_float(&vec, &mut out)?;
+ }
+ GgmlDType::Q4K => {
+ let vec: Vec<crate::quantized::BlockQ4K> =
+ read_to_vec(&buffer, elem_count / self.dtype.block_size());
+ use crate::quantized::k_quants::GgmlType;
+ crate::quantized::BlockQ4K::to_float(&vec, &mut out)?;
+ }
+ GgmlDType::Q5K => {
+ let vec: Vec<crate::quantized::BlockQ5K> =
+ read_to_vec(&buffer, elem_count / self.dtype.block_size());
+ use crate::quantized::k_quants::GgmlType;
+ crate::quantized::BlockQ5K::to_float(&vec, &mut out)?;
+ }
+ GgmlDType::Q6K => {
+ let vec: Vec<crate::quantized::BlockQ6K> =
+ read_to_vec(&buffer, elem_count / self.dtype.block_size());
+ use crate::quantized::k_quants::GgmlType;
+ crate::quantized::BlockQ6K::to_float(&vec, &mut out)?;
+ }
+ GgmlDType::Q8K => {
+ let vec: Vec<crate::quantized::BlockQ8K> =
+ read_to_vec(&buffer, elem_count / self.dtype.block_size());
+ use crate::quantized::k_quants::GgmlType;
+ crate::quantized::BlockQ8K::to_float(&vec, &mut out)?;
+ }
+ }
+
+ let buffer = self.device.new_buffer_with_data(&out)?;
+ Ok(MetalStorage::new(buffer, self.device.clone(), DType::F32))
+ }
+
+ pub fn quantize(&mut self, src: &MetalStorage) -> Result<()> {
+ // Quantization only happens on CPU for now.
+ let src = src.to_cpu::<f32>()?;
+ let elem_count = src.len();
+ let src = crate::Storage::Cpu(crate::CpuStorage::F32(src));
+ let mut qcpu_storage = crate::Device::Cpu.qzeros(elem_count, self.dtype)?;
+ qcpu_storage.quantize(&src)?;
+ let buffer = self.device.new_buffer_with_data(&qcpu_storage.data()?)?;
+ self.buffer = buffer;
+ Ok(())
+ }
+}
+
+pub fn load_quantized_metal<T: super::GgmlType + Send + Sync + 'static>(
+ device: &MetalDevice,
+ data: &[T],
+) -> Result<QStorage> {
+ let buffer = device.new_buffer_with_data(data)?;
+ let device = device.clone();
+ Ok(QStorage::Metal(QMetalStorage {
+ dtype: T::DTYPE,
+ device,
+ buffer,
+ }))
+}
+
+fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> {
+ let ptr = buffer.contents() as *const T;
+ assert!(!ptr.is_null());
+ let slice = unsafe { std::slice::from_raw_parts(ptr, n) };
+ slice.to_vec()
+}
diff --git a/candle-core/src/quantized/mod.rs b/candle-core/src/quantized/mod.rs
index 043733ae..1dc5fe8f 100644
--- a/candle-core/src/quantized/mod.rs
+++ b/candle-core/src/quantized/mod.rs
@@ -1,23 +1,125 @@
-use crate::{Device, Result, Shape, Tensor};
+#[cfg(feature = "metal")]
+use crate::{backend::BackendStorage, DType};
+use crate::{CpuStorage, Device, Result, Shape, Storage, Tensor};
+use k_quants::*;
+use std::borrow::Cow;
#[cfg(target_feature = "avx")]
pub mod avx;
pub mod ggml_file;
pub mod gguf_file;
pub mod k_quants;
+#[cfg(feature = "metal")]
+pub mod metal;
#[cfg(target_feature = "neon")]
pub mod neon;
#[cfg(target_feature = "simd128")]
pub mod simd128;
pub mod utils;
+use half::f16;
pub use k_quants::GgmlType;
pub struct QTensor {
- data: Box<dyn QuantizedType>,
+ storage: QStorage,
shape: Shape,
}
+impl Device {
+ fn qzeros(&self, elem_count: usize, dtype: GgmlDType) -> Result<QStorage> {
+ match self {
+ Device::Cpu => {
+ let storage = dtype.cpu_zeros(elem_count);
+ Ok(QStorage::Cpu(storage))
+ }
+ #[cfg(feature = "metal")]
+ Device::Metal(metal) => {
+ let size = elem_count * dtype.type_size() / dtype.block_size();
+ let buffer = metal.allocate_zeros(size)?;
+ Ok(QStorage::Metal(metal::QMetalStorage::new(
+ buffer,
+ metal.clone(),
+ dtype,
+ )))
+ }
+ #[cfg(not(feature = "metal"))]
+ Device::Metal(_metal) => {
+ crate::bail!("Metal feature not activated");
+ }
+ Device::Cuda(_cuda) => {
+ crate::bail!("Cuda ggml quantization not supported");
+ }
+ }
+ }
+}
+
+pub enum QStorage {
+ Cpu(Box<dyn QuantizedType>),
+ #[cfg(feature = "metal")]
+ Metal(metal::QMetalStorage),
+}
+
+impl QStorage {
+ fn block_size(&self) -> usize {
+ match self {
+ QStorage::Cpu(storage) => storage.block_size(),
+ #[cfg(feature = "metal")]
+ QStorage::Metal(storage) => storage.dtype().block_size(),
+ }
+ }
+
+ fn dtype(&self) -> GgmlDType {
+ match self {
+ QStorage::Cpu(storage) => storage.dtype(),
+ #[cfg(feature = "metal")]
+ QStorage::Metal(storage) => storage.dtype(),
+ }
+ }
+
+ fn size_in_bytes(&self) -> usize {
+ match self {
+ QStorage::Cpu(storage) => storage.storage_size_in_bytes(),
+ #[cfg(feature = "metal")]
+ QStorage::Metal(storage) => storage.buffer().length() as usize,
+ }
+ }
+
+ fn quantize(&mut self, src: &Storage) -> Result<()> {
+ match (self, src) {
+ (QStorage::Cpu(storage), Storage::Cpu(src)) => {
+ storage.from_float(src.as_slice::<f32>()?)?;
+ }
+ #[cfg(feature = "metal")]
+ (QStorage::Metal(storage), Storage::Metal(src)) => storage.quantize(src)?,
+ _ => crate::bail!("Invalid dequantize storage locations do not match"),
+ }
+ Ok(())
+ }
+
+ fn dequantize(&self, elem_count: usize) -> Result<Storage> {
+ match self {
+ QStorage::Cpu(storage) => Ok(Storage::Cpu(storage.dequantize(elem_count)?)),
+ #[cfg(feature = "metal")]
+ QStorage::Metal(storage) => Ok(Storage::Metal(storage.dequantize(elem_count)?)),
+ }
+ }
+
+ fn data(&self) -> Result<Cow<[u8]>> {
+ match self {
+ QStorage::Cpu(storage) => {
+ let data_ptr = storage.as_ptr();
+ let size_in_bytes = storage.storage_size_in_bytes();
+ let data = unsafe { std::slice::from_raw_parts(data_ptr, size_in_bytes) };
+ Ok(Cow::from(data))
+ }
+ #[cfg(feature = "metal")]
+ QStorage::Metal(_storage) => {
+ crate::bail!("not implemented");
+ }
+ }
+ }
+}
+
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum GgmlDType {
F32,
@@ -77,6 +179,25 @@ impl GgmlDType {
}
}
+ /// The block dtype
+ pub fn cpu_zeros(&self, elem_count: usize) -> Box<dyn QuantizedType> {
+ match self {
+ Self::F32 => Box::new(vec![f32::zeros(); elem_count]),
+ Self::F16 => Box::new(vec![f16::zeros(); elem_count]),
+ Self::Q4_0 => Box::new(vec![BlockQ4_0::zeros(); elem_count / BlockQ4_0::BLCK_SIZE]),
+ Self::Q4_1 => Box::new(vec![BlockQ4_1::zeros(); elem_count / BlockQ4_1::BLCK_SIZE]),
+ Self::Q5_0 => Box::new(vec![BlockQ5_0::zeros(); elem_count / BlockQ5_0::BLCK_SIZE]),
+ Self::Q5_1 => Box::new(vec![BlockQ5_1::zeros(); elem_count / BlockQ5_1::BLCK_SIZE]),
+ Self::Q8_0 => Box::new(vec![BlockQ8_0::zeros(); elem_count / BlockQ8_0::BLCK_SIZE]),
+ Self::Q8_1 => Box::new(vec![BlockQ8_1::zeros(); elem_count / BlockQ8_1::BLCK_SIZE]),
+ Self::Q2K => Box::new(vec![BlockQ2K::zeros(); elem_count / BlockQ2K::BLCK_SIZE]),
+ Self::Q3K => Box::new(vec![BlockQ3K::zeros(); elem_count / BlockQ3K::BLCK_SIZE]),
+ Self::Q4K => Box::new(vec![BlockQ4K::zeros(); elem_count / BlockQ4K::BLCK_SIZE]),
+ Self::Q5K => Box::new(vec![BlockQ5K::zeros(); elem_count / BlockQ5K::BLCK_SIZE]),
+ Self::Q6K => Box::new(vec![BlockQ6K::zeros(); elem_count / BlockQ6K::BLCK_SIZE]),
+ Self::Q8K => Box::new(vec![BlockQ8K::zeros(); elem_count / BlockQ8K::BLCK_SIZE]),
+ }
+ }
/// The type size for blocks in bytes.
pub fn type_size(&self) -> usize {
use k_quants::*;
@@ -100,7 +221,7 @@ impl GgmlDType {
}
/// The block size, i.e. the number of elements stored in each block.
- pub fn blck_size(&self) -> usize {
+ pub fn block_size(&self) -> usize {
match self {
Self::F32 => 1,
Self::F16 => 1,
@@ -119,9 +240,13 @@ impl GgmlDType {
pub trait QuantizedType: Send + Sync {
fn dtype(&self) -> GgmlDType;
fn matmul_t(&self, mkn: (usize, usize, usize), lhs: &[f32], dst: &mut [f32]) -> Result<()>;
- fn to_float(&self, ys: &mut [f32]) -> Result<()>;
+ fn dequantize(&self, elem_count: usize) -> Result<CpuStorage>;
fn storage_size_in_bytes(&self) -> usize;
fn as_ptr(&self) -> *const u8;
+ fn block_size(&self) -> usize;
+ #[allow(clippy::wrong_self_convention)]
+ fn from_float(&mut self, xs: &[f32]) -> Result<()>;
+ fn size(&self) -> usize;
}
impl<T: k_quants::GgmlType + Send + Sync> QuantizedType for Vec<T> {
@@ -129,12 +254,26 @@ impl<T: k_quants::GgmlType + Send + Sync> QuantizedType for Vec<T> {
k_quants::matmul(mkn, lhs, self.as_slice(), dst)
}
+ fn size(&self) -> usize {
+ self.len() * core::mem::size_of::<T>()
+ }
+
+ fn from_float(&mut self, xs: &[f32]) -> Result<()> {
+ T::from_float(xs, self)
+ }
+
fn dtype(&self) -> GgmlDType {
T::DTYPE
}
- fn to_float(&self, ys: &mut [f32]) -> Result<()> {
- T::to_float(self.as_slice(), ys)
+ fn block_size(&self) -> usize {
+ T::BLCK_SIZE
+ }
+
+ fn dequantize(&self, elem_count: usize) -> Result<CpuStorage> {
+ let mut ys = vec![0.0f32; elem_count];
+ T::to_float(self.as_slice(), &mut ys)?;
+ Ok(CpuStorage::F32(ys))
}
fn storage_size_in_bytes(&self) -> usize {
@@ -152,56 +291,49 @@ impl std::fmt::Debug for QTensor {
}
}
-fn check_shape<T: k_quants::GgmlType>(shape: &Shape) -> Result<()> {
+fn check_shape(shape: &Shape, block_size: usize) -> Result<()> {
let dims = shape.dims();
if dims.is_empty() {
crate::bail!("scalar tensor cannot be quantized {shape:?}")
}
- if dims[dims.len() - 1] % T::BLCK_SIZE != 0 {
+ if dims[dims.len() - 1] % block_size != 0 {
crate::bail!(
"quantized tensor must have their last dim divisible by block size {shape:?} {}",
- T::BLCK_SIZE
+ block_size
)
}
Ok(())
}
impl QTensor {
- pub fn new<S: Into<Shape>, T: k_quants::GgmlType + Send + Sync + 'static>(
- data: Vec<T>,
- shape: S,
- ) -> Result<Self> {
+ pub fn new<S: Into<Shape>>(storage: QStorage, shape: S) -> Result<Self> {
let shape = shape.into();
- check_shape::<T>(&shape)?;
- Ok(Self {
- data: Box::new(data),
- shape,
- })
+ check_shape(&shape, storage.block_size())?;
+ Ok(Self { storage, shape })
}
- pub fn quantize<T: k_quants::GgmlType + Send + Sync + 'static>(src: &Tensor) -> Result<Self> {
+ pub fn quantize(src: &Tensor, dtype: GgmlDType) -> Result<Self> {
let shape = src.shape();
- check_shape::<T>(shape)?;
- let src = src
- .to_dtype(crate::DType::F32)?
- .flatten_all()?
- .to_vec1::<f32>()?;
- if src.len() % T::BLCK_SIZE != 0 {
+ let block_size = dtype.block_size();
+ check_shape(shape, block_size)?;
+ let src = src.to_dtype(crate::DType::F32)?.flatten_all()?;
+ let elem_count = shape.elem_count();
+ if elem_count % block_size != 0 {
crate::bail!(
"tensor size ({shape:?}) is not divisible by block size {}",
- T::BLCK_SIZE
+ block_size
)
}
- let mut data = vec![T::zeros(); src.len() / T::BLCK_SIZE];
- T::from_float(&src, &mut data)?;
+ let mut storage = src.device().qzeros(elem_count, dtype)?;
+ storage.quantize(&src.storage())?;
Ok(Self {
- data: Box::new(data),
+ storage,
shape: shape.clone(),
})
}
pub fn dtype(&self) -> GgmlDType {
- self.data.dtype()
+ self.storage.dtype()
}
pub fn rank(&self) -> usize {
@@ -213,21 +345,19 @@ impl QTensor {
}
pub fn dequantize(&self, device: &Device) -> Result<Tensor> {
- let mut f32_data = vec![0f32; self.shape.elem_count()];
- self.data.to_float(&mut f32_data)?;
- Tensor::from_vec(f32_data, &self.shape, device)
- }
-
- pub fn matmul_t(&self, mkn: (usize, usize, usize), lhs: &[f32], dst: &mut [f32]) -> Result<()> {
- self.data.matmul_t(mkn, lhs, dst)
+ let storage = self.storage.dequantize(self.shape.elem_count())?;
+ let none = crate::op::BackpropOp::none();
+ let is_variable = false;
+ crate::tensor::from_storage(storage, self.shape.clone(), none, is_variable)
+ .to_device(device)
}
pub fn storage_size_in_bytes(&self) -> usize {
- self.data.storage_size_in_bytes()
+ self.storage.size_in_bytes()
}
- pub fn as_ptr(&self) -> *const u8 {
- self.data.as_ptr()
+ pub fn data(&self) -> Result<Cow<'_, [u8]>> {
+ self.storage.data()
}
}
@@ -294,17 +424,93 @@ impl crate::CustomOp1 for QTensor {
}
dst_shape.push(n);
let dst_shape = Shape::from(dst_shape);
- let storage = storage.as_slice::<f32>()?;
- let storage =
- &storage[layout.start_offset()..layout.start_offset() + src_shape.elem_count()];
+ #[allow(clippy::infallible_destructuring_match)]
+ let self_storage = match &self.storage {
+ QStorage::Cpu(storage) => storage,
+ #[cfg(feature = "metal")]
+ _ => crate::bail!("Invalid storage"),
+ };
+ let slice = storage.as_slice::<f32>()?;
+ let slice = &slice[layout.start_offset()..layout.start_offset() + src_shape.elem_count()];
let mut dst_storage = vec![0f32; dst_shape.elem_count()];
- self.matmul_t(
- (dst_shape.elem_count() / n, k, n),
- storage,
- &mut dst_storage,
- )?;
+ self_storage.matmul_t((dst_shape.elem_count() / n, k, n), slice, &mut dst_storage)?;
Ok((crate::CpuStorage::F32(dst_storage), dst_shape))
}
+
+ #[cfg(feature = "metal")]
+ fn metal_fwd(
+ &self,
+ storage: &crate::MetalStorage,
+ layout: &crate::Layout,
+ ) -> Result<(crate::MetalStorage, Shape)> {
+ use crate::MetalError;
+
+ if !layout.is_contiguous() {
+ crate::bail!("input tensor is not contiguous {layout:?}")
+ }
+ let src_shape = layout.shape();
+ // self is transposed so n is first then k.
+ if src_shape.rank() < 2 {
+ crate::bail!("input tensor has only one dimension {layout:?}")
+ }
+ let (n, k) = self.shape.dims2()?;
+ let mut dst_shape = src_shape.dims().to_vec();
+
+ let (b, m) = match dst_shape.len() {
+ 3 => (dst_shape[0], dst_shape[1]),
+ 2 => (1, dst_shape[0]),
+ n => crate::bail!("Invalid rank {n} for quantized matmul metal"),
+ };
+ let last_k = dst_shape.pop().unwrap();
+ if last_k != k {
+ crate::bail!("input tensor {layout:?} incompatible with {:?}", self.shape)
+ }
+ dst_shape.push(n);
+ let dst_shape = Shape::from(dst_shape);
+ let device = storage.device().clone();
+ let dst = device.new_buffer(dst_shape.elem_count(), DType::F32, "qmatmul")?;
+ let (buffer, dtype) = match &self.storage {
+ QStorage::Metal(metal) => (metal.buffer(), metal.dtype()),
+ _ => unreachable!("Cannot call metal matmul on non metal QTensor"),
+ };
+ let command_buffer = device.command_buffer()?;
+ candle_metal_kernels::call_quantized_matmul_t(
+ device.device(),
+ &command_buffer,
+ device.kernels(),
+ dtype.into(),
+ (b, m, n, k),
+ storage.buffer(),
+ layout.start_offset() * storage.dtype().size_in_bytes(),
+ buffer,
+ &dst,
+ )
+ .map_err(MetalError::from)?;
+ let dst_storage = crate::MetalStorage::new(dst, device, DType::F32);
+ Ok((dst_storage, dst_shape))
+ }
+}
+
+#[cfg(feature = "metal")]
+impl From<GgmlDType> for candle_metal_kernels::GgmlDType {
+ fn from(value: GgmlDType) -> Self {
+ match value {
+ GgmlDType::Q4_0 => candle_metal_kernels::GgmlDType::Q4_0,
+ GgmlDType::Q4_1 => candle_metal_kernels::GgmlDType::Q4_1,
+ GgmlDType::Q5_0 => candle_metal_kernels::GgmlDType::Q5_0,
+ GgmlDType::Q5_1 => candle_metal_kernels::GgmlDType::Q5_1,
+ GgmlDType::Q8_0 => candle_metal_kernels::GgmlDType::Q8_0,
+ GgmlDType::Q8_1 => candle_metal_kernels::GgmlDType::Q8_1,
+ GgmlDType::Q2K => candle_metal_kernels::GgmlDType::Q2K,
+ GgmlDType::Q3K => candle_metal_kernels::GgmlDType::Q3K,
+ GgmlDType::Q4K => candle_metal_kernels::GgmlDType::Q4K,
+ GgmlDType::Q5K => candle_metal_kernels::GgmlDType::Q5K,
+ GgmlDType::Q6K => candle_metal_kernels::GgmlDType::Q6K,
+ GgmlDType::Q8K => candle_metal_kernels::GgmlDType::Q8K,
+ GgmlDType::F16 => candle_metal_kernels::GgmlDType::F16,
+ GgmlDType::F32 => candle_metal_kernels::GgmlDType::F32,
+ }
+ }
}
impl crate::Module for QMatMul {
diff --git a/candle-core/tests/quantized_tests.rs b/candle-core/tests/quantized_tests.rs
index d31e77a7..a7811ca5 100644
--- a/candle-core/tests/quantized_tests.rs
+++ b/candle-core/tests/quantized_tests.rs
@@ -1,6 +1,7 @@
use candle_core::{
bail,
quantized::{self, GgmlDType},
+ test_device,
test_utils::to_vec2_round,
Device, Module, Result, Tensor,
};
@@ -14,16 +15,48 @@ const GGML_MAX_QUANTIZATION_TOTAL_ERROR_2BITS: f32 = 0.0075;
const GGML_MAX_QUANTIZATION_TOTAL_ERROR_3BITS: f32 = 0.0040;
const GGML_MAX_DOT_PRODUCT_ERROR: f32 = 0.02;
-#[test]
-fn quantized_matmul() -> Result<()> {
- let cpu = &Device::Cpu;
+fn test_matmul(
+ device: &Device,
+ (b, m, n, k): (usize, usize, usize, usize),
+ dtype: GgmlDType,
+) -> Result<()> {
+ let lhs = (0..(m * k))
+ .map(|v| v as f32 / (m * k) as f32)
+ .collect::<Vec<_>>();
+ let rhs = (0..(k * n))
+ .map(|v| v as f32 / (n * k) as f32)
+ .collect::<Vec<_>>();
+
+ let lhs = Tensor::from_slice(&lhs, (m, k), device)?;
+ let rhs = Tensor::from_slice(&rhs, (k, n), device)?;
+ let mm = lhs.matmul(&rhs)?;
+ let qtensor = quantized::QTensor::quantize(&rhs.t()?, dtype)?;
+ let matmul = quantized::QMatMul::from_qtensor(qtensor)?;
+ let res = matmul.forward(&lhs)?;
+
+ let error: f32 = ((&mm - &res)?.abs()? / &mm.abs()?)?
+ .sum_all()?
+ .to_scalar()?;
+ let error = error / (b * m * n) as f32;
+ assert!(
+ error <= 0.02,
+ "Error {error} is too big. \nExpected:\n {mm} \nFound:\n {res}\n for {dtype:?}"
+ );
+
+ Ok(())
+}
+
+fn quantized_matmul(device: &Device) -> Result<()> {
+ // TODO Enable this later when we enable cuda.
+ if device.is_cuda() {
+ return Ok(());
+ }
let (m, k, n) = (3, 64, 4);
let lhs = (0..(m * k)).map(|v| v as f32).collect::<Vec<_>>();
- let tensor_lhs = Tensor::from_slice(&lhs, (m, k), cpu)?;
+ let tensor_lhs = Tensor::from_slice(&lhs, (m, k), device)?;
let mut dst = vec![42.; 3 * 4];
let mut rhs_t = vec![k_quants::BlockQ4_0::zeros(); 8];
let rhs = (0..(k * n)).map(|v| v as f32).collect::<Vec<_>>();
- let tensor_rhs = Tensor::from_slice(&rhs, (n, k), cpu)?.t()?;
k_quants::BlockQ4_0::from_float(&rhs, &mut rhs_t)?;
k_quants::matmul((m, k, n), &lhs, &rhs_t, &mut dst)?;
assert_eq!(
@@ -33,6 +66,7 @@ fn quantized_matmul() -> Result<()> {
341876.0, 994283.0, 1655709.0, 2301518.0
]
);
+ let tensor_rhs = Tensor::from_slice(&rhs, (n, k), device)?.t()?;
let mm = tensor_lhs.matmul(&tensor_rhs)?;
assert_eq!(
mm.to_vec2::<f32>()?,
@@ -43,35 +77,49 @@ fn quantized_matmul() -> Result<()> {
]
);
- let qtensor = quantized::QTensor::new(rhs_t, (4, 64))?;
+ let qtensor = quantized::QTensor::quantize(&tensor_rhs.t()?, GgmlDType::Q4_0)?;
let matmul = quantized::QMatMul::from_qtensor(qtensor)?;
let res = matmul.forward(&tensor_lhs)?;
- assert_eq!(
- to_vec2_round(&res, 0)?,
- &[
- [85120.0, 214562.0, 345455.0, 474748.0],
- [213475.0, 604465.0, 1000686.0, 1388317.0],
- [341876.0, 994283.0, 1655709.0, 2301518.0]
- ]
- );
+ match device {
+ Device::Metal(_) => assert_eq!(
+ to_vec2_round(&res, 0)?,
+ &[
+ [84946.0, 214126.0, 344757.0, 473798.0],
+ [213458.0, 604350.0, 1000469.0, 1387990.0],
+ [341970.0, 994574.0, 1656181.0, 2302182.0]
+ ]
+ ),
+ _ => assert_eq!(
+ to_vec2_round(&res, 0)?,
+ &[
+ [85120.0, 214562.0, 345455.0, 474748.0],
+ [213475.0, 604465.0, 1000686.0, 1388317.0],
+ [341876.0, 994283.0, 1655709.0, 2301518.0]
+ ]
+ ),
+ }
+
+ test_matmul(device, (1, 3, 4, 256), GgmlDType::Q4_0)?;
Ok(())
}
-#[test]
-fn quantized_matmul_neg() -> Result<()> {
- let cpu = &Device::Cpu;
+fn quantized_matmul_neg(device: &Device) -> Result<()> {
+ // TODO Enable this later when we enable cuda.
+ if device.is_cuda() {
+ return Ok(());
+ }
let (m, k, n) = (3, 64, 4);
let lhs = (0..(m * k))
.map(|v| v as f32 - (m * k) as f32 / 2.0)
.collect::<Vec<_>>();
- let tensor_lhs = Tensor::from_slice(&lhs, (m, k), cpu)?;
+ let tensor_lhs = Tensor::from_slice(&lhs, (m, k), device)?;
let mut dst = vec![42.; 3 * 4];
let mut rhs_t = vec![k_quants::BlockQ4_0::zeros(); 8];
let rhs = (0..k * n)
.map(|v| v as f32 - (k * n) as f32 / 3.0)
.collect::<Vec<_>>();
- let tensor_rhs = Tensor::from_slice(&rhs, (n, k), cpu)?.t()?;
+ let tensor_rhs = Tensor::from_slice(&rhs, (n, k), device)?.t()?;
k_quants::BlockQ4_0::from_float(&rhs, &mut rhs_t)?;
k_quants::matmul((m, k, n), &lhs, &rhs_t, &mut dst)?;
assert_eq!(
@@ -91,32 +139,56 @@ fn quantized_matmul_neg() -> Result<()> {
]
);
- let qtensor = quantized::QTensor::new(rhs_t, (4, 64))?;
+ let qtensor = quantized::QTensor::quantize(&tensor_rhs.t()?, GgmlDType::Q4_0)?;
let matmul = quantized::QMatMul::from_qtensor(qtensor)?;
let res = matmul.forward(&tensor_lhs)?;
- assert_eq!(
- to_vec2_round(&res, 0)?,
- &[
- [243524.0, -19596.0, -285051.0, -549815.0],
- [23777.0, 21651.0, 19398.0, 18367.0],
- [-196472.0, 63012.0, 324585.0, 587902.0]
- ]
- );
+ match device {
+ Device::Metal(_) => assert_eq!(
+ to_vec2_round(&res, 0)?,
+ &[
+ [243666.0, -19714.0, -285433.0, -550453.0],
+ [23782.0, 21654.0, 19400.0, 18369.0],
+ [-196102.0, 63022.0, 324233.0, 587191.0]
+ ]
+ ),
+ _ => assert_eq!(
+ to_vec2_round(&res, 0)?,
+ &[
+ [243524.0, -19596.0, -285051.0, -549815.0],
+ [23777.0, 21651.0, 19398.0, 18367.0],
+ [-196472.0, 63012.0, 324585.0, 587902.0]
+ ]
+ ),
+ }
Ok(())
}
-#[test]
-fn quantize_q4_0() -> Result<()> {
- use k_quants::BlockQ4_0;
-
+test_device!(
+ quantized_matmul,
+ quantized_matmul_cpu,
+ quantized_matmul_cuda,
+ quantized_matmul_metal
+);
+test_device!(
+ quantized_matmul_neg,
+ quantized_matmul_neg_cpu,
+ quantized_matmul_neg_cuda,
+ quantized_matmul_neg_metal
+);
+
+fn quantize_q4_0(device: &Device) -> Result<()> {
+ // TODO Enable this later when we enable cuda.
+ if device.is_cuda() {
+ return Ok(());
+ }
let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
- let mut dst = vec![0f32; 32 * 4];
- let mut quant = vec![BlockQ4_0::zeros(); 4];
- BlockQ4_0::from_float(&src, &mut quant)?;
- BlockQ4_0::to_float(&quant, dst.as_mut_slice())?;
+
+ let src = Tensor::from_slice(&src, (32 * 4,), device)?;
+ let quant = quantized::QTensor::quantize(&src, GgmlDType::Q4_0)?;
+ let dst = quant.dequantize(device)?;
assert_eq!(
- dst,
+ dst.to_vec1::<f32>()?,
&[
-0.0, -0.0, 3.875, 3.875, 3.875, 3.875, 7.75, 7.75, 7.75, 7.75, 11.625, 11.625, 11.625,
11.625, 15.5, 15.5, 15.5, 15.5, 19.375, 19.375, 19.375, 19.375, 23.25, 23.25, 23.25,
@@ -132,21 +204,21 @@ fn quantize_q4_0() -> Result<()> {
127.0, 127.0
]
);
- ggml_quantization_error_test::<BlockQ4_0>(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
+ ggml_quantization_error_test(GgmlDType::Q4_0, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
Ok(())
}
-#[test]
-fn quantize_q4_1() -> Result<()> {
- use k_quants::BlockQ4_1;
-
+fn quantize_q4_1(device: &Device) -> Result<()> {
+ // TODO Enable this later when we enable cuda.
+ if device.is_cuda() {
+ return Ok(());
+ }
let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
- let mut dst = vec![0f32; 32 * 4];
- let mut quant = vec![BlockQ4_1::zeros(); 4];
- BlockQ4_1::from_float(&src, &mut quant)?;
- BlockQ4_1::to_float(&quant, dst.as_mut_slice())?;
+ let src = Tensor::from_slice(&src, (32 * 4,), device)?;
+ let quant = quantized::QTensor::quantize(&src, GgmlDType::Q4_1)?;
+ let dst = quant.dequantize(device)?;
assert_eq!(
- round_vector(&dst),
+ round_vector(&dst.to_vec1::<f32>()?),
&[
0.0, 0.0, 2.066, 2.066, 4.133, 4.133, 6.199, 6.199, 8.266, 8.266, 10.332, 10.332,
12.398, 12.398, 14.465, 14.465, 16.531, 16.531, 18.598, 18.598, 20.664, 20.664, 22.73,
@@ -162,21 +234,21 @@ fn quantize_q4_1() -> Result<()> {
118.73, 118.73, 120.797, 120.797, 122.863, 122.863, 124.93, 124.93, 126.996, 126.996
]
);
- ggml_quantization_error_test::<BlockQ4_1>(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
+ ggml_quantization_error_test(GgmlDType::Q4_1, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
Ok(())
}
-#[test]
-fn quantize_q5_0() -> Result<()> {
- use k_quants::BlockQ5_0;
-
+fn quantize_q5_0(device: &Device) -> Result<()> {
+ // TODO Enable this later when we enable cuda.
+ if device.is_cuda() {
+ return Ok(());
+ }
let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
- let mut dst = vec![0f32; 32 * 4];
- let mut quant = vec![BlockQ5_0::zeros(); 4];
- BlockQ5_0::from_float(&src, &mut quant)?;
- BlockQ5_0::to_float(&quant, dst.as_mut_slice())?;
+ let src = Tensor::from_slice(&src, (32 * 4,), device)?;
+ let quant = quantized::QTensor::quantize(&src, GgmlDType::Q5_0)?;
+ let dst = quant.dequantize(device)?;
assert_eq!(
- round_vector(&dst),
+ round_vector(&dst.to_vec1::<f32>()?),
&[
-0.0, 1.938, 1.938, 3.875, 3.875, 5.813, 5.813, 7.75, 7.75, 9.688, 9.688, 11.625,
11.625, 13.563, 13.563, 15.5, 15.5, 17.438, 17.438, 19.375, 19.375, 21.313, 21.313,
@@ -192,21 +264,21 @@ fn quantize_q5_0() -> Result<()> {
119.063, 119.063, 119.063, 119.063, 127.0, 127.0, 127.0, 127.0
]
);
- ggml_quantization_error_test::<BlockQ5_0>(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
+ ggml_quantization_error_test(GgmlDType::Q5_0, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
Ok(())
}
-#[test]
-fn quantize_q5_1() -> Result<()> {
- use k_quants::BlockQ5_1;
-
+fn quantize_q5_1(device: &Device) -> Result<()> {
+ // TODO Enable this later when we enable cuda.
+ if device.is_cuda() {
+ return Ok(());
+ }
let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
- let mut dst = vec![0f32; 32 * 4];
- let mut quant = vec![BlockQ5_1::zeros(); 4];
- BlockQ5_1::from_float(&src, &mut quant)?;
- BlockQ5_1::to_float(&quant, dst.as_mut_slice())?;
+ let src = Tensor::from_slice(&src, (32 * 4,), device)?;
+ let quant = quantized::QTensor::quantize(&src, GgmlDType::Q5_1)?;
+ let dst = quant.dequantize(device)?;
assert_eq!(
- dst,
+ round_vector(&dst.to_vec1::<f32>()?),
&[
0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0,
16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0,
@@ -220,13 +292,11 @@ fn quantize_q5_1() -> Result<()> {
124.0, 125.0, 126.0, 127.0
]
);
-
- ggml_quantization_error_test::<BlockQ5_1>(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
+ ggml_quantization_error_test(GgmlDType::Q5_1, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
Ok(())
}
-/// Generates a small test vector ranging from -`bound` to `bound` with `size` steps
-fn get_test_vector(bound: f32, size: usize) -> (Vec<f32>, Vec<f32>) {
+fn get_test_vector2(bound: f32, size: usize, device: &Device) -> Result<Tensor> {
assert!(
size % crate::quantized::k_quants::QK_K == 0,
"size must be a multiple of {}",
@@ -236,10 +306,8 @@ fn get_test_vector(bound: f32, size: usize) -> (Vec<f32>, Vec<f32>) {
let src = (0..size)
.map(|v| (v as f32 - size as f32 / 2.) * bound / (size as f32 / 2.))
.collect::<Vec<_>>();
-
- let dst = vec![0f32; size];
assert_eq!([src[0], src[size / 2]], [-bound, 0.0]);
- (src, dst)
+ Tensor::from_vec(src, (size,), device)
}
/// Round a vector
@@ -288,11 +356,12 @@ fn calculate_rmse(a: &[f32], b: &[f32]) -> f32 {
/// Similar to the GGML quantization unit test:
/// https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L43-L50
-fn ggml_quantization_error_test<T: GgmlType>(max_error: f32) -> Result<()> {
+fn ggml_quantization_error_test(dtype: GgmlDType, device: &Device, max_error: f32) -> Result<()> {
let src = create_ggml_like_vector(0.0);
- let mut dst = vec![0.0; GGML_TEST_SIZE];
- let _quant = quantize_roundtrip::<T>(src.as_slice(), dst.as_mut_slice())?;
- let error = calculate_rmse(src.as_slice(), dst.as_slice());
+ let src = Tensor::from_slice(&src, (GGML_TEST_SIZE,), device)?;
+ let quant = quantized::QTensor::quantize(&src, dtype)?;
+ let dst = quant.dequantize(device)?;
+ let error = calculate_rmse(&src.to_vec1::<f32>()?, &dst.to_vec1::<f32>()?);
if error > max_error {
bail!(
"Quantization error {} exceeds max error {}",
@@ -303,19 +372,19 @@ fn ggml_quantization_error_test<T: GgmlType>(max_error: f32) -> Result<()> {
Ok(())
}
-fn quantize_roundtrip<T: GgmlType>(src: &[f32], dst: &mut [f32]) -> Result<Vec<T>> {
- let mut quant = vec![T::zeros(); src.len() / T::BLCK_SIZE];
- T::from_float(src, &mut quant)?;
- T::to_float(&quant, dst)?;
- Ok(quant)
-}
+fn quantize_q2k(device: &Device) -> Result<()> {
+ // TODO Enable this later when we enable cuda.
+ if device.is_cuda() {
+ return Ok(());
+ }
+ let dtype = GgmlDType::Q2K;
-#[test]
-fn quantize_q2k() -> Result<()> {
- use k_quants::BlockQ2K;
+ let src = get_test_vector2(0.5, 1024, device)?;
+ let quant = quantized::QTensor::quantize(&src, dtype)?;
+ let dst = quant.dequantize(device)?;
- let (src, mut dst) = get_test_vector(0.5, 1024);
- let _quant = quantize_roundtrip::<BlockQ2K>(src.as_slice(), dst.as_mut_slice())?;
+ let src = src.to_vec1::<f32>()?;
+ let dst = dst.to_vec1::<f32>()?;
compare_with_error(dst.as_slice(), src.as_slice(), 0.1);
// Test some specific values
@@ -329,20 +398,30 @@ fn quantize_q2k() -> Result<()> {
[-0.499, -0.366, -0.249, 0.0, 0.295, 0.492]
);
- let (src_big, mut dst_big) = get_test_vector(128.0, 1024);
- let _quant_big = quantize_roundtrip::<BlockQ2K>(src_big.as_slice(), dst_big.as_mut_slice())?;
+ let src_big = get_test_vector2(128.0, 1024, device)?;
+ let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
+ let dst_big = quant_big.dequantize(device)?;
+
+ let src_big = src_big.to_vec1::<f32>()?;
+ let dst_big = dst_big.to_vec1::<f32>()?;
compare_with_error(dst_big.as_slice(), src_big.as_slice(), 6.0);
- ggml_quantization_error_test::<BlockQ2K>(GGML_MAX_QUANTIZATION_TOTAL_ERROR_2BITS)?;
+ ggml_quantization_error_test(dtype, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR_2BITS)?;
Ok(())
}
-#[test]
-fn quantize_q3k() -> Result<()> {
- use k_quants::BlockQ3K;
+fn quantize_q3k(device: &Device) -> Result<()> {
+ // TODO Enable this later when we enable cuda.
+ if device.is_cuda() {
+ return Ok(());
+ }
+ let dtype = GgmlDType::Q3K;
+ let src = get_test_vector2(0.5, 1024, device)?;
+ let quant = quantized::QTensor::quantize(&src, dtype)?;
+ let dst = quant.dequantize(device)?;
- let (src, mut dst) = get_test_vector(0.5, 1024);
- let _quant = quantize_roundtrip::<BlockQ3K>(src.as_slice(), dst.as_mut_slice())?;
+ let src = src.to_vec1::<f32>()?;
+ let dst = dst.to_vec1::<f32>()?;
compare_with_error(dst.as_slice(), src.as_slice(), 0.03);
// Test some specific values
@@ -356,20 +435,30 @@ fn quantize_q3k() -> Result<()> {
[-0.493, -0.37, -0.243, -0.0, 0.292, 0.492]
);
- let (src_big, mut dst_big) = get_test_vector(128.0, 1024);
- let _quant_big = quantize_roundtrip::<BlockQ3K>(src_big.as_slice(), dst_big.as_mut_slice())?;
+ let src_big = get_test_vector2(128.0, 1024, device)?;
+ let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
+ let dst_big = quant_big.dequantize(device)?;
+
+ let src_big = src_big.to_vec1::<f32>()?;
+ let dst_big = dst_big.to_vec1::<f32>()?;
compare_with_error(dst_big.as_slice(), src_big.as_slice(), 3.5);
- ggml_quantization_error_test::<BlockQ3K>(GGML_MAX_QUANTIZATION_TOTAL_ERROR_3BITS)?;
+ ggml_quantization_error_test(dtype, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR_3BITS)?;
Ok(())
}
-#[test]
-fn quantize_q4k() -> Result<()> {
- use k_quants::BlockQ4K;
+fn quantize_q4k(device: &Device) -> Result<()> {
+ // TODO Enable this later when we enable cuda.
+ if device.is_cuda() {
+ return Ok(());
+ }
+ let dtype = GgmlDType::Q4K;
+ let src = get_test_vector2(0.5, 1024, device)?;
+ let quant = quantized::QTensor::quantize(&src, dtype)?;
+ let dst = quant.dequantize(device)?;
- let (src, mut dst) = get_test_vector(0.5, 1024);
- let _quant = quantize_roundtrip::<BlockQ4K>(src.as_slice(), dst.as_mut_slice())?;
+ let src = src.to_vec1::<f32>()?;
+ let dst = dst.to_vec1::<f32>()?;
compare_with_error(dst.as_slice(), src.as_slice(), 0.017);
// Test some specific values
@@ -383,21 +472,31 @@ fn quantize_q4k() -> Result<()> {
[-0.5, -0.373, -0.25, 0.0, 0.288, 0.498]
);
- let (src_big, mut dst_big) = get_test_vector(128.0, 1024);
- let _quant_big = quantize_roundtrip::<BlockQ4K>(src_big.as_slice(), dst_big.as_mut_slice())?;
+ let src_big = get_test_vector2(128.0, 1024, device)?;
+ let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
+ let dst_big = quant_big.dequantize(device)?;
+
+ let src_big = src_big.to_vec1::<f32>()?;
+ let dst_big = dst_big.to_vec1::<f32>()?;
compare_with_error(dst_big.as_slice(), src_big.as_slice(), 4.5);
- ggml_quantization_error_test::<BlockQ4K>(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
+ ggml_quantization_error_test(dtype, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
Ok(())
}
-#[test]
-fn quantize_q5k() -> Result<()> {
- use k_quants::BlockQ5K;
+fn quantize_q5k(device: &Device) -> Result<()> {
+ // TODO Enable this later when we enable cuda.
+ if device.is_cuda() {
+ return Ok(());
+ }
+ let dtype = GgmlDType::Q5K;
+ let src = get_test_vector2(0.5, 1024, device)?;
+ let quant = quantized::QTensor::quantize(&src, dtype)?;
+ let dst = quant.dequantize(device)?;
- let (src, mut dst) = get_test_vector(0.5, 1024);
- let _quant = quantize_roundtrip::<BlockQ5K>(src.as_slice(), dst.as_mut_slice())?;
- compare_with_error(dst.as_slice(), src.as_slice(), 0.008);
+ let src = src.to_vec1::<f32>()?;
+ let dst = dst.to_vec1::<f32>()?;
+ compare_with_error(dst.as_slice(), src.as_slice(), 0.009);
// Test some specific values
assert_eq!(
@@ -410,21 +509,30 @@ fn quantize_q5k() -> Result<()> {
[-0.5, -0.373, -0.25, 0.0, 0.279, 0.499]
);
- let (src_big, mut dst_big) = get_test_vector(128.0, 1024);
- let _quant_big = quantize_roundtrip::<BlockQ5K>(src_big.as_slice(), dst_big.as_mut_slice())?;
- compare_with_error(dst_big.as_slice(), src_big.as_slice(), 2.5);
+ let src_big = get_test_vector2(128.0, 1024, device)?;
+ let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
+ let dst_big = quant_big.dequantize(device)?;
- ggml_quantization_error_test::<BlockQ5K>(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
+ let src_big = src_big.to_vec1::<f32>()?;
+ let dst_big = dst_big.to_vec1::<f32>()?;
+ compare_with_error(dst_big.as_slice(), src_big.as_slice(), 2.5);
+ ggml_quantization_error_test(dtype, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
Ok(())
}
-#[test]
-fn quantize_q6k() -> Result<()> {
- use k_quants::BlockQ6K;
+fn quantize_q6k(device: &Device) -> Result<()> {
+ // TODO Enable this later when we enable cuda.
+ if device.is_cuda() {
+ return Ok(());
+ }
+ let dtype = GgmlDType::Q6K;
+ let src = get_test_vector2(0.5, 1024, device)?;
+ let quant = quantized::QTensor::quantize(&src, dtype)?;
+ let dst = quant.dequantize(device)?;
- let (src, mut dst) = get_test_vector(0.5, 1024);
- let _quant = quantize_roundtrip::<BlockQ6K>(src.as_slice(), dst.as_mut_slice())?;
+ let src = src.to_vec1::<f32>()?;
+ let dst = dst.to_vec1::<f32>()?;
compare_with_error(dst.as_slice(), src.as_slice(), 0.008);
// Test some specific values
@@ -438,22 +546,31 @@ fn quantize_q6k() -> Result<()> {
[-0.497, -0.372, -0.25, -0.0, 0.284, 0.5]
);
- let (src_big, mut dst_big) = get_test_vector(128.0, 1024);
- let _quant_big = quantize_roundtrip::<BlockQ6K>(src_big.as_slice(), dst_big.as_mut_slice())?;
- compare_with_error(dst_big.as_slice(), src_big.as_slice(), 2.0);
+ let src_big = get_test_vector2(128.0, 1024, device)?;
+ let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
+ let dst_big = quant_big.dequantize(device)?;
- ggml_quantization_error_test::<BlockQ6K>(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
+ let src_big = src_big.to_vec1::<f32>()?;
+ let dst_big = dst_big.to_vec1::<f32>()?;
+ compare_with_error(dst_big.as_slice(), src_big.as_slice(), 2.0);
+ ggml_quantization_error_test(dtype, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
Ok(())
}
-#[test]
-fn quantize_q8k() -> Result<()> {
- use k_quants::BlockQ8K;
+fn quantize_q8k(device: &Device) -> Result<()> {
+ // TODO Enable this later when we enable cuda.
+ if device.is_cuda() {
+ return Ok(());
+ }
+ let dtype = GgmlDType::Q8K;
+ let src = get_test_vector2(0.5, 1024, device)?;
+ let quant = quantized::QTensor::quantize(&src, dtype)?;
+ let dst = quant.dequantize(device)?;
- let (src, mut dst) = get_test_vector(0.5, 1024);
- let _quant = quantize_roundtrip::<BlockQ8K>(src.as_slice(), dst.as_mut_slice())?;
- compare_with_error(dst.as_slice(), src.as_slice(), 0.003);
+ let src = src.to_vec1::<f32>()?;
+ let dst = dst.to_vec1::<f32>()?;
+ compare_with_error(dst.as_slice(), src.as_slice(), 0.008);
// Test some specific values
assert_eq!(
@@ -466,15 +583,79 @@ fn quantize_q8k() -> Result<()> {
[-0.5, -0.375, -0.25, -0.0, 0.281, 0.499]
);
- let (src_big, mut dst_big) = get_test_vector(128.0, 1024);
- let _quant_big = quantize_roundtrip::<BlockQ8K>(src_big.as_slice(), dst_big.as_mut_slice())?;
- compare_with_error(dst_big.as_slice(), src_big.as_slice(), 0.6);
+ let src_big = get_test_vector2(128.0, 1024, device)?;
+ let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
+ let dst_big = quant_big.dequantize(device)?;
- ggml_quantization_error_test::<BlockQ8K>(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
+ let src_big = src_big.to_vec1::<f32>()?;
+ let dst_big = dst_big.to_vec1::<f32>()?;
+ compare_with_error(dst_big.as_slice(), src_big.as_slice(), 0.6);
+ ggml_quantization_error_test(dtype, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
Ok(())
}
+test_device!(
+ quantize_q4_0,
+ quantize_q4_0_cpu,
+ quantize_q4_0_cuda,
+ quantize_q4_0_metal
+);
+test_device!(
+ quantize_q4_1,
+ quantize_q4_1_cpu,
+ quantize_q4_1_cuda,
+ quantize_q4_1_metal
+);
+test_device!(
+ quantize_q5_0,
+ quantize_q5_0_cpu,
+ quantize_q5_0_cuda,
+ quantize_q5_0_metal
+);
+test_device!(
+ quantize_q5_1,
+ quantize_q5_1_cpu,
+ quantize_q5_1_cuda,
+ quantize_q5_1_metal
+);
+test_device!(
+ quantize_q2k,
+ quantize_q2k_cpu,
+ quantize_q2k_cuda,
+ quantize_q2k_metal
+);
+test_device!(
+ quantize_q3k,
+ quantize_q3k_cpu,
+ quantize_q3k_cuda,
+ quantize_q3k_metal
+);
+test_device!(
+ quantize_q4k,
+ quantize_q4k_cpu,
+ quantize_q4k_cuda,
+ quantize_q4k_metal
+);
+test_device!(
+ quantize_q5k,
+ quantize_q5k_cpu,
+ quantize_q5k_cuda,
+ quantize_q5k_metal
+);
+test_device!(
+ quantize_q6k,
+ quantize_q6k_cpu,
+ quantize_q6k_cuda,
+ quantize_q6k_metal
+);
+test_device!(
+ quantize_q8k,
+ quantize_q8k_cpu,
+ quantize_q8k_cuda,
+ quantize_q8k_metal
+);
+
/// Very simple dot product implementation
fn vec_dot_reference(a: &[f32], b: &[f32]) -> f32 {
a.iter().zip(b).map(|(a, b)| a * b).sum()
@@ -591,6 +772,112 @@ fn get_random_tensors(
Ok((lhs, rhs, mm))
}
+#[macro_export]
+macro_rules! quantized_matmul {
+ // TODO: Switch to generating the two last arguments automatically once concat_idents is
+ // stable. https://github.com/rust-lang/rust/issues/29599
+ ($fn_name: ident, $fn_name_cpu: ident, $fn_name_cuda: ident, $fn_name_metal: ident, $dtype: expr) => {
+ fn $fn_name(device: &Device) -> Result<()> {
+ if device.is_cuda() {
+ // TODO Enable Cuda GGML sometime maybe.
+ return Ok(());
+ }
+ test_matmul(device, (1, 3, 4, 256), $dtype)?;
+ Ok(())
+ }
+
+ test_device!($fn_name, $fn_name_cpu, $fn_name_cuda, $fn_name_metal);
+ };
+}
+
+quantized_matmul!(
+ quantized_matmul_q4_0_bis,
+ quantized_matmul_q4_0_cpu,
+ quantized_matmul_q4_0_cuda,
+ quantized_matmul_q4_0_metal,
+ GgmlDType::Q4_0
+);
+quantized_matmul!(
+ quantized_matmul_q4_1_bis,
+ quantized_matmul_q4_1_cpu,
+ quantized_matmul_q4_1_cuda,
+ quantized_matmul_q4_1_metal,
+ GgmlDType::Q4_1
+);
+quantized_matmul!(
+ quantized_matmul_q5_0_bis,
+ quantized_matmul_q5_0_cpu,
+ quantized_matmul_q5_0_cuda,
+ quantized_matmul_q5_0_metal,
+ GgmlDType::Q5_0
+);
+quantized_matmul!(
+ quantized_matmul_q5_1_bis,
+ quantized_matmul_q5_1_cpu,
+ quantized_matmul_q5_1_cuda,
+ quantized_matmul_q5_1_metal,
+ GgmlDType::Q5_1
+);
+quantized_matmul!(
+ quantized_matmul_q8_0_bis,
+ quantized_matmul_q8_0_cpu,
+ quantized_matmul_q8_0_cuda,
+ quantized_matmul_q8_0_metal,
+ GgmlDType::Q8_0
+);
+// Not implemented in Ggml
+// quantized_matmul!(
+// quantized_matmul_q8_1_bis,
+// quantized_matmul_q8_1_cpu,
+// quantized_matmul_q8_1_cuda,
+// quantized_matmul_q8_1_metal,
+// GgmlDType::Q8_1
+// );
+// TODO This is bugged (also bugged in GGML
+quantized_matmul!(
+ quantized_matmul_q2k_bis,
+ quantized_matmul_q2k_cpu,
+ quantized_matmul_q2k_cuda,
+ quantized_matmul_q2k_metal,
+ GgmlDType::Q2K
+);
+quantized_matmul!(
+ quantized_matmul_q3k_bis,
+ quantized_matmul_q3k_cpu,
+ quantized_matmul_q3k_cuda,
+ quantized_matmul_q3k_metal,
+ GgmlDType::Q3K
+);
+quantized_matmul!(
+ quantized_matmul_q4k_bis,
+ quantized_matmul_q4k_cpu,
+ quantized_matmul_q4k_cuda,
+ quantized_matmul_q4k_metal,
+ GgmlDType::Q4K
+);
+quantized_matmul!(
+ quantized_matmul_q5k_bis,
+ quantized_matmul_q5k_cpu,
+ quantized_matmul_q5k_cuda,
+ quantized_matmul_q5k_metal,
+ GgmlDType::Q5K
+);
+quantized_matmul!(
+ quantized_matmul_q6k_bis,
+ quantized_matmul_q6k_cpu,
+ quantized_matmul_q6k_cuda,
+ quantized_matmul_q6k_metal,
+ GgmlDType::Q6K
+);
+// Not implemented on metal
+// quantized_matmul!(
+// quantized_matmul_q8k_bis,
+// quantized_matmul_q8k_cpu,
+// quantized_matmul_q8k_cuda,
+// quantized_matmul_q8k_metal,
+// GgmlDType::Q8K
+// );
+
#[test]
fn quantized_matmul_q2k() -> Result<()> {
use k_quants::BlockQ2K;
@@ -603,7 +890,7 @@ fn quantized_matmul_q2k() -> Result<()> {
let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
- let rhs = quantized::QTensor::quantize::<BlockQ2K>(&rhs)?;
+ let rhs = quantized::QTensor::quantize(&rhs, GgmlDType::Q2K)?;
let rhs = quantized::QMatMul::from_qtensor(rhs)?;
let mm = rhs.forward(&lhs)?;
@@ -629,7 +916,7 @@ fn quantized_matmul_q3k() -> Result<()> {
let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
- let rhs = quantized::QTensor::quantize::<BlockQ3K>(&rhs)?;
+ let rhs = quantized::QTensor::quantize(&rhs, GgmlDType::Q3K)?;
let rhs = quantized::QMatMul::from_qtensor(rhs)?;
let mm = rhs.forward(&lhs)?;
@@ -655,7 +942,7 @@ fn quantized_matmul_q4k() -> Result<()> {
let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
- let rhs = quantized::QTensor::quantize::<BlockQ4K>(&rhs)?;
+ let rhs = quantized::QTensor::quantize(&rhs, GgmlDType::Q4K)?;
let rhs = quantized::QMatMul::from_qtensor(rhs)?;
let mm = rhs.forward(&lhs)?;
@@ -681,7 +968,7 @@ fn quantized_matmul_q5k() -> Result<()> {
let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
- let rhs = quantized::QTensor::quantize::<BlockQ5K>(&rhs)?;
+ let rhs = quantized::QTensor::quantize(&rhs, GgmlDType::Q5K)?;
let rhs = quantized::QMatMul::from_qtensor(rhs)?;
let mm = rhs.forward(&lhs)?;
@@ -708,7 +995,7 @@ fn quantized_matmul_q6k() -> Result<()> {
let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
- let rhs = quantized::QTensor::quantize::<BlockQ6K>(&rhs)?;
+ let rhs = quantized::QTensor::quantize(&rhs, GgmlDType::Q6K)?;
let rhs = quantized::QMatMul::from_qtensor(rhs)?;
let mm = rhs.forward(&lhs)?;
@@ -733,7 +1020,7 @@ fn quantized_matmul_q8k() -> Result<()> {
let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
- let rhs = quantized::QTensor::quantize::<BlockQ8K>(&rhs)?;
+ let rhs = quantized::QTensor::quantize(&rhs, GgmlDType::Q8K)?;
let rhs = quantized::QMatMul::from_qtensor(rhs)?;
let mm = rhs.forward(&lhs)?;
diff --git a/candle-examples/examples/blip/main.rs b/candle-examples/examples/blip/main.rs
index a1051a8e..15e36476 100644
--- a/candle-examples/examples/blip/main.rs
+++ b/candle-examples/examples/blip/main.rs
@@ -106,17 +106,17 @@ pub fn main() -> anyhow::Result<()> {
let config = blip::Config::image_captioning_large();
+ let device = candle_examples::device(args.cpu)?;
let (image_embeds, device, mut model) = if args.quantized {
let device = Device::Cpu;
let image = load_image(args.image)?.to_device(&device)?;
println!("loaded image {image:?}");
- let vb = quantized_blip::VarBuilder::from_gguf(model_file)?;
+ let vb = quantized_blip::VarBuilder::from_gguf(model_file, &device)?;
let model = quantized_blip::BlipForConditionalGeneration::new(&config, vb)?;
let image_embeds = image.unsqueeze(0)?.apply(model.vision_model())?;
(image_embeds, device, Model::Q(model))
} else {
- let device = candle_examples::device(args.cpu)?;
let image = load_image(args.image)?.to_device(&device)?;
println!("loaded image {image:?}");
diff --git a/candle-examples/examples/llama2-c/main.rs b/candle-examples/examples/llama2-c/main.rs
index 0ceb27af..9d42dcc8 100644
--- a/candle-examples/examples/llama2-c/main.rs
+++ b/candle-examples/examples/llama2-c/main.rs
@@ -262,7 +262,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
.extension()
.map_or(false, |v| v == "safetensors");
let (model, config) = if is_gguf {
- let vb = qmodel::VarBuilder::from_gguf(config_path)?;
+ let vb = qmodel::VarBuilder::from_gguf(config_path, &device)?;
let (_vocab_size, dim) = vb
.get_no_shape("model.embed_tokens.weight")?
.shape()
@@ -279,13 +279,13 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
(config.seq_len, config.head_size() / 2),
"rot.freq_cis_real",
)?
- .dequantize(&candle::Device::Cpu)?;
+ .dequantize(&device)?;
let freq_cis_imag = vb
.get(
(config.seq_len, config.head_size() / 2),
"rot.freq_cis_imag",
)?
- .dequantize(&candle::Device::Cpu)?;
+ .dequantize(&device)?;
let fake_vb = candle_nn::VarBuilder::from_tensors(
[
@@ -295,7 +295,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
.into_iter()
.collect(),
candle::DType::F32,
- &candle::Device::Cpu,
+ &device,
);
let cache = model::Cache::new(true, &config, fake_vb)?;
let model = Model::QLlama(QLlama::load(vb, &cache, config.clone())?);
diff --git a/candle-examples/examples/mistral/main.rs b/candle-examples/examples/mistral/main.rs
index 5ed5e5cb..bad86098 100644
--- a/candle-examples/examples/mistral/main.rs
+++ b/candle-examples/examples/mistral/main.rs
@@ -244,13 +244,14 @@ fn main() -> Result<()> {
let start = std::time::Instant::now();
let config = Config::config_7b_v0_1(args.use_flash_attn);
+ let device = candle_examples::device(args.cpu)?;
let (model, device) = if args.quantized {
let filename = &filenames[0];
- let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename)?;
+ let vb =
+ candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename, &device)?;
let model = QMistral::new(&config, vb)?;
- (Model::Quantized(model), Device::Cpu)
+ (Model::Quantized(model), device)
} else {
- let device = candle_examples::device(args.cpu)?;
let dtype = if device.is_cuda() {
DType::BF16
} else {
diff --git a/candle-examples/examples/phi/main.rs b/candle-examples/examples/phi/main.rs
index 69eed84f..39f4fd58 100644
--- a/candle-examples/examples/phi/main.rs
+++ b/candle-examples/examples/phi/main.rs
@@ -307,18 +307,21 @@ fn main() -> Result<()> {
WhichModel::PuffinPhiV2 => Config::puffin_phi_v2(),
WhichModel::PhiHermes => Config::phi_hermes_1_3b(),
};
- let (model, device) = if args.quantized {
- let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filenames[0])?;
+ let device = candle_examples::device(args.cpu)?;
+ let model = if args.quantized {
let config = config();
+ let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(
+ &filenames[0],
+ &device,
+ )?;
let model = match args.model {
WhichModel::V2 | WhichModel::V2Old => QMixFormer::new_v2(&config, vb)?,
_ => QMixFormer::new(&config, vb)?,
};
- (Model::Quantized(model), Device::Cpu)
+ Model::Quantized(model)
} else {
- let device = candle_examples::device(args.cpu)?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
- let model = match args.model {
+ match args.model {
WhichModel::V1 | WhichModel::V1_5 | WhichModel::V2 => {
let config_filename = repo.get("config.json")?;
let config = std::fs::read_to_string(config_filename)?;
@@ -334,8 +337,7 @@ fn main() -> Result<()> {
let config = config();
Model::MixFormer(MixFormer::new(&config, vb)?)
}
- };
- (model, device)
+ }
};
println!("loaded the model in {:?}", start.elapsed());
diff --git a/candle-examples/examples/quantized-t5/main.rs b/candle-examples/examples/quantized-t5/main.rs
index 0ea2e0bd..ed3f1030 100644
--- a/candle-examples/examples/quantized-t5/main.rs
+++ b/candle-examples/examples/quantized-t5/main.rs
@@ -132,7 +132,8 @@ impl T5ModelBuilder {
}
pub fn build_model(&self) -> Result<t5::T5ForConditionalGeneration> {
- let vb = t5::VarBuilder::from_gguf(&self.weights_filename)?;
+ let device = Device::Cpu;
+ let vb = t5::VarBuilder::from_gguf(&self.weights_filename, &device)?;
Ok(t5::T5ForConditionalGeneration::load(vb, &self.config)?)
}
diff --git a/candle-examples/examples/quantized/main.rs b/candle-examples/examples/quantized/main.rs
index bfc6de53..34c44233 100644
--- a/candle-examples/examples/quantized/main.rs
+++ b/candle-examples/examples/quantized/main.rs
@@ -9,7 +9,7 @@ use std::io::Write;
use tokenizers::Tokenizer;
use candle::quantized::{ggml_file, gguf_file};
-use candle::{Device, Tensor};
+use candle::Tensor;
use candle_transformers::generation::LogitsProcessor;
use candle_examples::token_output_stream::TokenOutputStream;
@@ -361,6 +361,7 @@ fn main() -> anyhow::Result<()> {
let model_path = args.model()?;
let mut file = std::fs::File::open(&model_path)?;
let start = std::time::Instant::now();
+ let device = candle_examples::device(false)?;
let mut model = match model_path.extension().and_then(|v| v.to_str()) {
Some("gguf") => {
@@ -369,7 +370,7 @@ fn main() -> anyhow::Result<()> {
for (_, tensor) in model.tensor_infos.iter() {
let elem_count = tensor.shape.elem_count();
total_size_in_bytes +=
- elem_count * tensor.ggml_dtype.type_size() / tensor.ggml_dtype.blck_size();
+ elem_count * tensor.ggml_dtype.type_size() / tensor.ggml_dtype.block_size();
}
println!(
"loaded {:?} tensors ({}) in {:.2}s",
@@ -377,15 +378,16 @@ fn main() -> anyhow::Result<()> {
&format_size(total_size_in_bytes),
start.elapsed().as_secs_f32(),
);
- ModelWeights::from_gguf(model, &mut file)?
+ ModelWeights::from_gguf(model, &mut file, &device)?
}
Some("ggml" | "bin") | Some(_) | None => {
- let model = ggml_file::Content::read(&mut file).map_err(|e| e.with_path(model_path))?;
+ let model = ggml_file::Content::read(&mut file, &device)
+ .map_err(|e| e.with_path(model_path))?;
let mut total_size_in_bytes = 0;
for (_, tensor) in model.tensors.iter() {
let elem_count = tensor.shape().elem_count();
total_size_in_bytes +=
- elem_count * tensor.dtype().type_size() / tensor.dtype().blck_size();
+ elem_count * tensor.dtype().type_size() / tensor.dtype().block_size();
}
println!(
"loaded {:?} tensors ({}) in {:.2}s",
@@ -486,7 +488,7 @@ fn main() -> anyhow::Result<()> {
let start_prompt_processing = std::time::Instant::now();
let mut next_token = {
- let input = Tensor::new(prompt_tokens.as_slice(), &Device::Cpu)?.unsqueeze(0)?;
+ let input = Tensor::new(prompt_tokens.as_slice(), &device)?.unsqueeze(0)?;
let logits = model.forward(&input, 0)?;
let logits = logits.squeeze(0)?;
logits_processor.sample(&logits)?
@@ -507,7 +509,7 @@ fn main() -> anyhow::Result<()> {
let start_post_prompt = std::time::Instant::now();
let mut sampled = 0;
for index in 0..to_sample {
- let input = Tensor::new(&[next_token], &Device::Cpu)?.unsqueeze(0)?;
+ let input = Tensor::new(&[next_token], &device)?.unsqueeze(0)?;
let logits = model.forward(&input, prompt_tokens.len() + index)?;
let logits = logits.squeeze(0)?;
let logits = if args.repeat_penalty == 1. {
diff --git a/candle-examples/examples/replit-code/main.rs b/candle-examples/examples/replit-code/main.rs
index 0f72b862..b7f767b9 100644
--- a/candle-examples/examples/replit-code/main.rs
+++ b/candle-examples/examples/replit-code/main.rs
@@ -236,16 +236,15 @@ fn main() -> Result<()> {
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let start = std::time::Instant::now();
+ let device = candle_examples::device(args.cpu)?;
let config = Config::replit_code_v1_5_3b();
- let (model, device) = if args.quantized {
- let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filename)?;
- let model = Model::Q(Q::new(&config, vb.pp("transformer"))?);
- (model, Device::Cpu)
+ let model = if args.quantized {
+ let vb =
+ candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filename, &device)?;
+ Model::Q(Q::new(&config, vb.pp("transformer"))?)
} else {
- let device = candle_examples::device(args.cpu)?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[filename], DType::F32, &device)? };
- let model = Model::M(M::new(&config, vb.pp("transformer"))?);
- (model, device)
+ Model::M(M::new(&config, vb.pp("transformer"))?)
};
println!("loaded the model in {:?}", start.elapsed());
diff --git a/candle-examples/examples/stable-lm/main.rs b/candle-examples/examples/stable-lm/main.rs
index 0535aa70..ccd924a4 100644
--- a/candle-examples/examples/stable-lm/main.rs
+++ b/candle-examples/examples/stable-lm/main.rs
@@ -234,13 +234,14 @@ fn main() -> Result<()> {
let start = std::time::Instant::now();
let config = Config::stablelm_3b_4e1t(args.use_flash_attn);
+ let device = candle_examples::device(args.cpu)?;
let (model, device) = if args.quantized {
let filename = &filenames[0];
- let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename)?;
+ let vb =
+ candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename, &device)?;
let model = QStableLM::new(&config, vb)?;
(Model::Quantized(model), Device::Cpu)
} else {
- let device = candle_examples::device(args.cpu)?;
let dtype = if device.is_cuda() {
DType::BF16
} else {
diff --git a/candle-examples/examples/whisper/main.rs b/candle-examples/examples/whisper/main.rs
index 5be81f2d..6ea34613 100644
--- a/candle-examples/examples/whisper/main.rs
+++ b/candle-examples/examples/whisper/main.rs
@@ -557,8 +557,10 @@ fn main() -> Result<()> {
println!("loaded mel: {:?}", mel.dims());
let mut model = if args.quantized {
- let vb =
- candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&weights_filename)?;
+ let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(
+ &weights_filename,
+ &device,
+ )?;
Model::Quantized(m::quantized_model::Whisper::load(&vb, config)?)
} else {
let vb =
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index c872dc60..201af97e 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -15,6 +15,7 @@ const CAST: &str = include_str!("cast.metal");
const REDUCE: &str = include_str!("reduce.metal");
const CONV: &str = include_str!("conv.metal");
const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");
+const QUANTIZED: &str = include_str!("quantized.metal");
/// Most kernels apply similarly across the tensors
/// This creates a strategy that uses the maximum amount of threads per threadgroup (capped at the
@@ -62,6 +63,8 @@ macro_rules! primitive {
};
}
primitive!(usize);
+primitive!(i64);
+primitive!(i32);
primitive!(u32);
primitive!(f32);
@@ -117,6 +120,7 @@ pub enum Source {
Reduce,
Mfa,
Conv,
+ Quantized,
}
macro_rules! ops{
@@ -215,17 +219,15 @@ type Pipelines = HashMap<(&'static str, Option<ConstantValues>), ComputePipeline
pub struct Kernels {
libraries: RwLock<Libraries>,
pipelines: RwLock<Pipelines>,
- fence: metal::Fence,
}
impl Kernels {
- pub fn new(fence: metal::Fence) -> Self {
+ pub fn new() -> Self {
let libraries = RwLock::new(Libraries::new());
let pipelines = RwLock::new(Pipelines::new());
Self {
libraries,
pipelines,
- fence,
}
}
@@ -239,6 +241,7 @@ impl Kernels {
Source::Cast => CAST,
Source::Reduce => REDUCE,
Source::Conv => CONV,
+ Source::Quantized => QUANTIZED,
Source::Mfa => panic!("Invalid lib"),
}
}
@@ -345,7 +348,6 @@ pub fn call_unary_contiguous(
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
let encoder = command_buffer.new_compute_command_encoder();
- encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (length, input, output));
@@ -354,7 +356,6 @@ pub fn call_unary_contiguous(
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
- encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
@@ -376,7 +377,6 @@ pub fn call_unary_strided(
let num_dims: usize = shape.len();
let encoder = command_buffer.new_compute_command_encoder();
- encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
let length: usize = shape.iter().product();
@@ -398,7 +398,6 @@ pub fn call_unary_strided(
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
- encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
@@ -417,7 +416,6 @@ pub fn call_binary_contiguous(
let pipeline = kernels.load_pipeline(device, Source::Binary, kernel_name.0)?;
let encoder = command_buffer.new_compute_command_encoder();
- encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (length, left, right, output));
@@ -428,7 +426,6 @@ pub fn call_binary_contiguous(
encoder.use_resource(right, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
- encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
@@ -453,7 +450,6 @@ pub fn call_binary_strided(
let num_dims: usize = shape.len();
let encoder = command_buffer.new_compute_command_encoder();
let width: usize = shape.iter().product();
- encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
let length: usize = shape.iter().product();
@@ -478,7 +474,6 @@ pub fn call_binary_strided(
encoder.use_resource(right_input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
- encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
@@ -497,7 +492,6 @@ pub fn call_cast_contiguous(
let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?;
let encoder = command_buffer.new_compute_command_encoder();
- encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (length, (input, input_offset), output));
@@ -506,7 +500,6 @@ pub fn call_cast_contiguous(
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
- encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
@@ -526,7 +519,6 @@ pub fn call_cast_strided(
let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?;
let encoder = command_buffer.new_compute_command_encoder();
- encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
let length: usize = shape.iter().product();
@@ -548,7 +540,6 @@ pub fn call_cast_strided(
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
- encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
@@ -568,7 +559,6 @@ pub fn call_reduce_contiguous(
let elements_to_sum = length / out_length;
let encoder = command_buffer.new_compute_command_encoder();
- encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
@@ -597,7 +587,6 @@ pub fn call_reduce_contiguous(
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
- encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
@@ -619,7 +608,6 @@ pub fn call_reduce_strided(
let elements_to_sum = length / out_length;
let encoder = command_buffer.new_compute_command_encoder();
- encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
@@ -655,7 +643,6 @@ pub fn call_reduce_strided(
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
- encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
@@ -674,7 +661,6 @@ pub fn call_last_softmax(
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
let encoder = command_buffer.new_compute_command_encoder();
- encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
@@ -705,7 +691,6 @@ pub fn call_last_softmax(
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
- encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
@@ -725,7 +710,6 @@ pub fn call_affine(
let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
let encoder = command_buffer.new_compute_command_encoder();
- encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (size, mul, add, input, output));
@@ -734,7 +718,6 @@ pub fn call_affine(
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
- encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
@@ -757,7 +740,6 @@ pub fn call_affine_strided(
let size: usize = shape.iter().product();
let encoder = command_buffer.new_compute_command_encoder();
- encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
@@ -778,7 +760,6 @@ pub fn call_affine_strided(
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
- encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
@@ -797,7 +778,6 @@ pub fn call_powf(
let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
let encoder = command_buffer.new_compute_command_encoder();
- encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (size, mul, input, output));
@@ -806,7 +786,6 @@ pub fn call_powf(
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
- encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
@@ -828,7 +807,6 @@ pub fn call_powf_strided(
let size: usize = shape.iter().product();
let encoder = command_buffer.new_compute_command_encoder();
- encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
@@ -848,7 +826,6 @@ pub fn call_powf_strided(
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
- encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
@@ -867,7 +844,6 @@ pub fn call_elu(
let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
let encoder = command_buffer.new_compute_command_encoder();
- encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (size, mul, input, output));
@@ -876,7 +852,6 @@ pub fn call_elu(
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
- encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
@@ -898,7 +873,6 @@ pub fn call_elu_strided(
let size: usize = shape.iter().product();
let encoder = command_buffer.new_compute_command_encoder();
- encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
@@ -918,7 +892,6 @@ pub fn call_elu_strided(
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
- encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
@@ -940,7 +913,6 @@ pub fn call_where_cond_strided(
let pipeline = kernels.load_pipeline(device, Source::Ternary, name)?;
let encoder = command_buffer.new_compute_command_encoder();
- encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
let size: usize = shape.iter().product();
@@ -969,7 +941,6 @@ pub fn call_where_cond_strided(
encoder.use_resource(right, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
- encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
@@ -996,7 +967,6 @@ pub fn call_index_select(
let encoder = command_buffer.new_compute_command_encoder();
- encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
@@ -1019,7 +989,6 @@ pub fn call_index_select(
encoder.use_resource(ids, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
- encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
@@ -1048,7 +1017,6 @@ pub fn call_gather(
let encoder = command_buffer.new_compute_command_encoder();
- encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
@@ -1071,7 +1039,6 @@ pub fn call_gather(
encoder.use_resource(ids, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
- encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
@@ -1100,7 +1067,6 @@ pub fn call_scatter_add(
let encoder = command_buffer.new_compute_command_encoder();
- encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
@@ -1123,7 +1089,6 @@ pub fn call_scatter_add(
encoder.use_resource(ids, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
- encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
@@ -1153,7 +1118,6 @@ pub fn call_index_add(
let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?;
let encoder = command_buffer.new_compute_command_encoder();
- encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
@@ -1177,7 +1141,6 @@ pub fn call_index_add(
encoder.use_resource(ids, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
- encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
@@ -1381,7 +1344,6 @@ pub fn call_gemm(
let block_bytes = block_elements * bytes;
let encoder = command_buffer.new_compute_command_encoder();
- encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
encoder.set_threadgroup_memory_length(0, block_bytes.into());
encoder.set_buffer(0, Some(lhs_buffer), lhs_offset as NSUInteger);
@@ -1421,12 +1383,10 @@ pub fn call_gemm(
height: 1,
depth: 1,
};
- // println!("grid size {grid_size:?} group size {group_size:?}");
encoder.use_resource(lhs_buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(rhs_buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(grid_size, group_size);
- encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
@@ -1451,7 +1411,6 @@ pub fn call_im2col1d_strided(
let encoder = command_buffer.new_compute_command_encoder();
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
- encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
encoder,
@@ -1471,7 +1430,6 @@ pub fn call_im2col1d_strided(
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
- encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
@@ -1501,7 +1459,6 @@ pub fn call_im2col_strided(
let encoder = command_buffer.new_compute_command_encoder();
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
- encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
encoder,
@@ -1523,7 +1480,6 @@ pub fn call_im2col_strided(
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
- encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
@@ -1549,7 +1505,6 @@ pub fn call_upsample_nearest_2d(
let scale_h = shape[3] as f32 / out_h as f32;
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
let encoder = command_buffer.new_compute_command_encoder();
- encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
encoder,
@@ -1567,7 +1522,176 @@ pub fn call_upsample_nearest_2d(
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
- encoder.update_fence(&kernels.fence);
+ encoder.end_encoding();
+
+ Ok(())
+}
+
+#[derive(Debug, Clone, Copy)]
+pub enum GgmlDType {
+ Q4_0,
+ Q4_1,
+ Q5_0,
+ Q5_1,
+ Q8_0,
+ Q8_1,
+ Q2K,
+ Q3K,
+ Q4K,
+ Q5K,
+ Q6K,
+ Q8K,
+ F16,
+ F32,
+}
+
+pub fn call_quantized_matmul_t(
+ device: &Device,
+ command_buffer: &CommandBufferRef,
+ kernels: &Kernels,
+ dtype: GgmlDType,
+ (b, m, n, k): (usize, usize, usize, usize),
+ lhs: &Buffer,
+ lhs_offset: usize,
+ rhs: &Buffer,
+ output: &Buffer,
+) -> Result<(), MetalKernelError> {
+ // Everything is in reverse
+ let ne00 = k as i64;
+ let ne01 = n as i64;
+ let ne02 = b as i64;
+ let ne03 = 1 as i64;
+
+ let nb00 = 0i64;
+ let nb01 = 0 as i64;
+ let nb02 = 0 as i64;
+
+ let ne10 = k as i64;
+ let ne11 = m as i64;
+ let ne12 = b as i64;
+ let ne13 = 1 as i64;
+
+ let nb10 = 0i64;
+ let nb11 = 0i64;
+ let nb12 = 0i64;
+
+ let ne0 = n as i64;
+ let ne1 = m as i64;
+ let r2: u32 = (ne12 / ne02) as u32;
+ let r3: u32 = (ne13 / ne03) as u32;
+
+ let (nth0, nth1, align) = match dtype {
+ GgmlDType::Q4_0
+ | GgmlDType::Q4_1
+ | GgmlDType::Q5_0
+ | GgmlDType::Q5_1
+ | GgmlDType::Q8_0
+ | GgmlDType::Q8_1 => {
+ let nth0 = 8;
+ let nth1 = 8;
+ let align = 8;
+ (nth0, nth1, align)
+ }
+ GgmlDType::Q2K => {
+ // Fixing a bug in Metal for GGML
+ let nth0 = 4;
+ let nth1 = 8;
+ let align = 4;
+ (nth0, nth1, align)
+ }
+ GgmlDType::Q4K => {
+ let nth0 = 4;
+ let nth1 = 8;
+ let align = 4;
+ (nth0, nth1, align)
+ }
+ GgmlDType::Q3K | GgmlDType::Q5K => {
+ let nth0 = 2;
+ let nth1 = 32;
+ let align = 4;
+ (nth0, nth1, align)
+ }
+ GgmlDType::Q6K => {
+ let nth0 = 2;
+ let nth1 = 32;
+ let align = 2;
+ (nth0, nth1, align)
+ }
+ GgmlDType::F16 | GgmlDType::Q8K => {
+ // Original implem uses rows
+ let nth0 = 32;
+ let nth1 = 1;
+ let align = 8;
+ (nth0, nth1, align)
+ }
+ GgmlDType::F32 => {
+ let nth0 = 32;
+ let nth1 = 1;
+ let align = 8;
+ (nth0, nth1, align)
+ }
+ };
+ let thread_groups_count = MTLSize {
+ width: divide(ne01 as usize, align),
+ height: ne11 as u64,
+ depth: (ne12 * ne13) as u64,
+ };
+ let threads_per_threadgroup = MTLSize {
+ width: nth0,
+ height: nth1,
+ depth: 1,
+ };
+ let name = match dtype {
+ GgmlDType::Q4_0 => "kernel_mul_mv_q4_0_f32",
+ GgmlDType::Q4_1 => "kernel_mul_mv_q4_1_f32",
+ GgmlDType::Q5_0 => "kernel_mul_mv_q5_0_f32",
+ GgmlDType::Q5_1 => "kernel_mul_mv_q5_1_f32",
+ GgmlDType::Q8_0 => "kernel_mul_mv_q8_0_f32",
+ GgmlDType::Q8_1 => "kernel_mul_mv_q8_1_f32",
+ GgmlDType::Q2K => "kernel_mul_mv_q2_K_f32",
+ GgmlDType::Q3K => "kernel_mul_mv_q3_K_f32",
+ GgmlDType::Q4K => "kernel_mul_mv_q4_K_f32",
+ GgmlDType::Q5K => "kernel_mul_mv_q5_K_f32",
+ GgmlDType::Q6K => "kernel_mul_mv_q6_K_f32",
+ GgmlDType::Q8K => "kernel_mul_mv_q8_K_f32",
+ GgmlDType::F16 => "kernel_mul_mv_f16_f32",
+ GgmlDType::F32 => "kernel_mul_mv_f32_f32",
+ };
+
+ let pipeline = kernels.load_pipeline(device, Source::Quantized, name)?;
+ let encoder = command_buffer.new_compute_command_encoder();
+ encoder.set_compute_pipeline_state(&pipeline);
+
+ set_params!(
+ encoder,
+ (
+ rhs,
+ (lhs, lhs_offset),
+ output,
+ ne00,
+ ne01,
+ ne02,
+ nb00,
+ nb01,
+ nb02,
+ ne10,
+ ne11,
+ ne12,
+ nb10,
+ nb11,
+ nb12,
+ ne0,
+ ne1,
+ r2,
+ r3
+ )
+ );
+ encoder.set_threadgroup_memory_length(0, 8192);
+ encoder.use_resource(lhs, metal::MTLResourceUsage::Read);
+ encoder.use_resource(rhs, metal::MTLResourceUsage::Read);
+ encoder.use_resource(output, metal::MTLResourceUsage::Write);
+
+ encoder.dispatch_thread_groups(thread_groups_count, threads_per_threadgroup);
encoder.end_encoding();
Ok(())
diff --git a/candle-metal-kernels/src/quantized.metal b/candle-metal-kernels/src/quantized.metal
new file mode 100644
index 00000000..9aa7b502
--- /dev/null
+++ b/candle-metal-kernels/src/quantized.metal
@@ -0,0 +1,5107 @@
+#include <metal_stdlib>
+
+using namespace metal;
+
+#define MAX(x, y) ((x) > (y) ? (x) : (y))
+#define MIN(x, y) ((x) < (y) ? (x) : (y))
+#define SWAP(x, y) { auto tmp = (x); (x) = (y); (y) = tmp; }
+
+#define QK4_0 32
+#define QR4_0 2
+typedef struct {
+ half d; // delta
+ uint8_t qs[QK4_0 / 2]; // nibbles / quants
+} block_q4_0;
+
+#define QK4_1 32
+typedef struct {
+ half d; // delta
+ half m; // min
+ uint8_t qs[QK4_1 / 2]; // nibbles / quants
+} block_q4_1;
+
+#define QK5_0 32
+typedef struct {
+ half d; // delta
+ uint8_t qh[4]; // 5-th bit of quants
+ uint8_t qs[QK5_0 / 2]; // nibbles / quants
+} block_q5_0;
+
+#define QK5_1 32
+typedef struct {
+ half d; // delta
+ half m; // min
+ uint8_t qh[4]; // 5-th bit of quants
+ uint8_t qs[QK5_1 / 2]; // nibbles / quants
+} block_q5_1;
+
+#define QK8_0 32
+typedef struct {
+ half d; // delta
+ int8_t qs[QK8_0]; // quants
+} block_q8_0;
+
+#define N_SIMDWIDTH 32 // assuming SIMD group size is 32
+
+enum ggml_sort_order {
+ GGML_SORT_ASC,
+ GGML_SORT_DESC,
+};
+
+// general-purpose kernel for addition, multiplication and division of two tensors
+// pros: works for non-contiguous tensors, supports broadcast across all dims
+// cons: not very efficient
+kernel void kernel_add(
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne03,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant uint64_t & nb03,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant uint64_t & nb13,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & ne2,
+ constant int64_t & ne3,
+ constant uint64_t & nb0,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ constant uint64_t & nb3,
+ constant int64_t & offs,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) {
+ const int64_t i03 = tgpig.z;
+ const int64_t i02 = tgpig.y;
+ const int64_t i01 = tgpig.x;
+
+ const int64_t i13 = i03 % ne13;
+ const int64_t i12 = i02 % ne12;
+ const int64_t i11 = i01 % ne11;
+
+ device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + offs;
+ device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
+ device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + offs;
+
+ for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+ const int i10 = i0 % ne10;
+ *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) + *((device float *)(src1_ptr + i10*nb10));
+ }
+}
+
+kernel void kernel_mul(
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne03,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant uint64_t & nb03,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant uint64_t & nb13,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & ne2,
+ constant int64_t & ne3,
+ constant uint64_t & nb0,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ constant uint64_t & nb3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) {
+ const int64_t i03 = tgpig.z;
+ const int64_t i02 = tgpig.y;
+ const int64_t i01 = tgpig.x;
+
+ const int64_t i13 = i03 % ne13;
+ const int64_t i12 = i02 % ne12;
+ const int64_t i11 = i01 % ne11;
+
+ device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
+ device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
+ device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1;
+
+ for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+ const int i10 = i0 % ne10;
+ *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) * *((device float *)(src1_ptr + i10*nb10));
+ }
+}
+
+kernel void kernel_div(
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne03,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant uint64_t & nb03,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant uint64_t & nb13,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & ne2,
+ constant int64_t & ne3,
+ constant uint64_t & nb0,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ constant uint64_t & nb3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) {
+ const int64_t i03 = tgpig.z;
+ const int64_t i02 = tgpig.y;
+ const int64_t i01 = tgpig.x;
+
+ const int64_t i13 = i03 % ne13;
+ const int64_t i12 = i02 % ne12;
+ const int64_t i11 = i01 % ne11;
+
+ device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
+ device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
+ device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1;
+
+ for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+ const int i10 = i0 % ne10;
+ *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) / *((device float *)(src1_ptr + i10*nb10));
+ }
+}
+
+// assumption: src1 is a row
+// broadcast src1 into src0
+kernel void kernel_add_row(
+ device const float4 * src0,
+ device const float4 * src1,
+ device float4 * dst,
+ constant uint64_t & nb [[buffer(28)]],
+ uint tpig[[thread_position_in_grid]]) {
+ dst[tpig] = src0[tpig] + src1[tpig % nb];
+}
+
+kernel void kernel_mul_row(
+ device const float4 * src0,
+ device const float4 * src1,
+ device float4 * dst,
+ constant uint64_t & nb [[buffer(28)]],
+ uint tpig[[thread_position_in_grid]]) {
+ dst[tpig] = src0[tpig] * src1[tpig % nb];
+}
+
+kernel void kernel_div_row(
+ device const float4 * src0,
+ device const float4 * src1,
+ device float4 * dst,
+ constant uint64_t & nb [[buffer(28)]],
+ uint tpig[[thread_position_in_grid]]) {
+ dst[tpig] = src0[tpig] / src1[tpig % nb];
+}
+
+kernel void kernel_scale(
+ device const float * src0,
+ device float * dst,
+ constant float & scale,
+ uint tpig[[thread_position_in_grid]]) {
+ dst[tpig] = src0[tpig] * scale;
+}
+
+kernel void kernel_scale_4(
+ device const float4 * src0,
+ device float4 * dst,
+ constant float & scale,
+ uint tpig[[thread_position_in_grid]]) {
+ dst[tpig] = src0[tpig] * scale;
+}
+
+kernel void kernel_relu(
+ device const float * src0,
+ device float * dst,
+ uint tpig[[thread_position_in_grid]]) {
+ dst[tpig] = max(0.0f, src0[tpig]);
+}
+
+kernel void kernel_tanh(
+ device const float * src0,
+ device float * dst,
+ uint tpig[[thread_position_in_grid]]) {
+ device const float & x = src0[tpig];
+ dst[tpig] = precise::tanh(x);
+}
+
+constant float GELU_COEF_A = 0.044715f;
+constant float GELU_QUICK_COEF = -1.702f;
+constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
+
+kernel void kernel_gelu(
+ device const float4 * src0,
+ device float4 * dst,
+ uint tpig[[thread_position_in_grid]]) {
+ device const float4 & x = src0[tpig];
+
+ // BEWARE !!!
+ // Simply using "tanh" instead of "precise::tanh" will sometimes results in NaNs!
+ // This was observed with Falcon 7B and 40B models
+ //
+ dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
+}
+
+kernel void kernel_gelu_quick(
+ device const float4 * src0,
+ device float4 * dst,
+ uint tpig[[thread_position_in_grid]]) {
+ device const float4 & x = src0[tpig];
+
+ dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
+}
+
+kernel void kernel_silu(
+ device const float4 * src0,
+ device float4 * dst,
+ uint tpig[[thread_position_in_grid]]) {
+ device const float4 & x = src0[tpig];
+ dst[tpig] = x / (1.0f + exp(-x));
+}
+
+kernel void kernel_sqr(
+ device const float * src0,
+ device float * dst,
+ uint tpig[[thread_position_in_grid]]) {
+ dst[tpig] = src0[tpig] * src0[tpig];
+}
+
+kernel void kernel_sum_rows(
+ device const float * src0,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne03,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant uint64_t & nb03,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant uint64_t & nb13,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & ne2,
+ constant int64_t & ne3,
+ constant uint64_t & nb0,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ constant uint64_t & nb3,
+ uint3 tpig[[thread_position_in_grid]]) {
+ int64_t i3 = tpig.z;
+ int64_t i2 = tpig.y;
+ int64_t i1 = tpig.x;
+
+ if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
+ return;
+ }
+
+ device const float * src_row = (device const float *) ((device const char *) src0 + i1*nb01 + i2*nb02 + i3*nb03);
+ device float * dst_row = (device float *) ((device char *) dst + i1*nb1 + i2*nb2 + i3*nb3);
+
+ float row_sum = 0;
+
+ for (int64_t i0 = 0; i0 < ne00; i0++) {
+ row_sum += src_row[i0];
+ }
+
+ dst_row[0] = row_sum;
+}
+
+kernel void kernel_soft_max(
+ device const float * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant float & scale,
+ threadgroup float * buf [[threadgroup(0)]],
+ uint tgpig[[threadgroup_position_in_grid]],
+ uint tpitg[[thread_position_in_threadgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint ntg[[threads_per_threadgroup]]) {
+ const int64_t i03 = (tgpig) / (ne02*ne01);
+ const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
+ const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
+
+ device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+ device const float * pmask = src1 != src0 ? src1 + i01*ne00 : nullptr;
+ device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+ // parallel max
+ float lmax = -INFINITY;
+
+ for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+ lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f));
+ }
+
+ // find the max value in the block
+ float max_val = simd_max(lmax);
+ if (ntg > N_SIMDWIDTH) {
+ if (sgitg == 0) {
+ buf[tiisg] = -INFINITY;
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ if (tiisg == 0) {
+ buf[sgitg] = max_val;
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ max_val = buf[tiisg];
+ max_val = simd_max(max_val);
+ }
+
+ // parallel sum
+ float lsum = 0.0f;
+ for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+ const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max_val);
+ lsum += exp_psrc0;
+ pdst[i00] = exp_psrc0;
+ }
+
+ // This barrier fixes a failing test
+ // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335
+ threadgroup_barrier(mem_flags::mem_none);
+
+ float sum = simd_sum(lsum);
+
+ if (ntg > N_SIMDWIDTH) {
+ if (sgitg == 0) {
+ buf[tiisg] = 0.0f;
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ if (tiisg == 0) {
+ buf[sgitg] = sum;
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ sum = buf[tiisg];
+ sum = simd_sum(sum);
+ }
+
+ const float inv_sum = 1.0f/sum;
+
+ for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+ pdst[i00] *= inv_sum;
+ }
+}
+
+kernel void kernel_soft_max_4(
+ device const float * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant float & scale,
+ threadgroup float * buf [[threadgroup(0)]],
+ uint tgpig[[threadgroup_position_in_grid]],
+ uint tpitg[[thread_position_in_threadgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint ntg[[threads_per_threadgroup]]) {
+ const int64_t i03 = (tgpig) / (ne02*ne01);
+ const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
+ const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
+
+ device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+ device const float4 * pmask = src1 != src0 ? (device const float4 *)(src1 + i01*ne00) : nullptr;
+ device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+
+ // parallel max
+ float4 lmax4 = -INFINITY;
+
+ for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
+ lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f));
+ }
+
+ const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
+
+ float max_val = simd_max(lmax);
+ if (ntg > N_SIMDWIDTH) {
+ if (sgitg == 0) {
+ buf[tiisg] = -INFINITY;
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ if (tiisg == 0) {
+ buf[sgitg] = max_val;
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ max_val = buf[tiisg];
+ max_val = simd_max(max_val);
+ }
+
+ // parallel sum
+ float4 lsum4 = 0.0f;
+ for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
+ const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max_val);
+ lsum4 += exp_psrc4;
+ pdst4[i00] = exp_psrc4;
+ }
+
+ const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
+
+ // This barrier fixes a failing test
+ // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335
+ threadgroup_barrier(mem_flags::mem_none);
+
+ float sum = simd_sum(lsum);
+
+ if (ntg > N_SIMDWIDTH) {
+ if (sgitg == 0) {
+ buf[tiisg] = 0.0f;
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ if (tiisg == 0) {
+ buf[sgitg] = sum;
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ sum = buf[tiisg];
+ sum = simd_sum(sum);
+ }
+
+ const float inv_sum = 1.0f/sum;
+
+ for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
+ pdst4[i00] *= inv_sum;
+ }
+}
+
+kernel void kernel_diag_mask_inf(
+ device const float * src0,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int & n_past,
+ uint3 tpig[[thread_position_in_grid]]) {
+ const int64_t i02 = tpig[2];
+ const int64_t i01 = tpig[1];
+ const int64_t i00 = tpig[0];
+
+ if (i00 > n_past + i01) {
+ dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY;
+ } else {
+ dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00];
+ }
+}
+
+kernel void kernel_diag_mask_inf_8(
+ device const float4 * src0,
+ device float4 * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int & n_past,
+ uint3 tpig[[thread_position_in_grid]]) {
+
+ const int64_t i = 2*tpig[0];
+
+ dst[i+0] = src0[i+0];
+ dst[i+1] = src0[i+1];
+ int64_t i4 = 4*i;
+ const int64_t i02 = i4/(ne00*ne01); i4 -= i02*ne00*ne01;
+ const int64_t i01 = i4/(ne00); i4 -= i01*ne00;
+ const int64_t i00 = i4;
+ for (int k = 3; k >= 0; --k) {
+ if (i00 + 4 + k <= n_past + i01) {
+ break;
+ }
+ dst[i+1][k] = -INFINITY;
+ if (i00 + k > n_past + i01) {
+ dst[i][k] = -INFINITY;
+ }
+ }
+}
+
+kernel void kernel_norm(
+ device const void * src0,
+ device float * dst,
+ constant int64_t & ne00,
+ constant uint64_t & nb01,
+ constant float & eps,
+ threadgroup float * sum [[threadgroup(0)]],
+ uint tgpig[[threadgroup_position_in_grid]],
+ uint tpitg[[thread_position_in_threadgroup]],
+ uint ntg[[threads_per_threadgroup]]) {
+ device const float * x = (device const float *) ((device const char *) src0 + tgpig*nb01);
+ // MEAN
+ // parallel sum
+ sum[tpitg] = 0.0f;
+ for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+ sum[tpitg] += x[i00];
+ }
+ // reduce
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ for (uint i = ntg/2; i > 0; i /= 2) {
+ if (tpitg < i) {
+ sum[tpitg] += sum[tpitg + i];
+ }
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ }
+ const float mean = sum[0] / ne00;
+
+ // recenter and VARIANCE
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ device float * y = dst + tgpig*ne00;
+ sum[tpitg] = 0.0f;
+ for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+ y[i00] = x[i00] - mean;
+ sum[tpitg] += y[i00] * y[i00];
+ }
+
+ // reduce
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ for (uint i = ntg/2; i > 0; i /= 2) {
+ if (tpitg < i) {
+ sum[tpitg] += sum[tpitg + i];
+ }
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ }
+ const float variance = sum[0] / ne00;
+
+ const float scale = 1.0f/sqrt(variance + eps);
+ for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+ y[i00] = y[i00] * scale;
+ }
+}
+
+kernel void kernel_rms_norm(
+ device const void * src0,
+ device float * dst,
+ constant int64_t & ne00,
+ constant uint64_t & nb01,
+ constant float & eps,
+ threadgroup float * buf [[threadgroup(0)]],
+ uint tgpig[[threadgroup_position_in_grid]],
+ uint tpitg[[thread_position_in_threadgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint ntg[[threads_per_threadgroup]]) {
+ device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
+
+ float4 sumf = 0;
+ float all_sum = 0;
+
+ // parallel sum
+ for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
+ sumf += x[i00] * x[i00];
+ }
+ all_sum = sumf[0] + sumf[1] + sumf[2] + sumf[3];
+ all_sum = simd_sum(all_sum);
+ if (ntg > N_SIMDWIDTH) {
+ if (sgitg == 0) {
+ buf[tiisg] = 0.0f;
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ if (tiisg == 0) {
+ buf[sgitg] = all_sum;
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ all_sum = buf[tiisg];
+ all_sum = simd_sum(all_sum);
+ }
+
+ const float mean = all_sum/ne00;
+ const float scale = 1.0f/sqrt(mean + eps);
+
+ device float4 * y = (device float4 *) (dst + tgpig*ne00);
+ for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
+ y[i00] = x[i00] * scale;
+ }
+}
+
+kernel void kernel_group_norm(
+ device const float * src0,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int32_t & n_groups,
+ constant float & eps,
+ threadgroup float * buf [[threadgroup(0)]],
+ uint tgpig[[threadgroup_position_in_grid]],
+ uint tpitg[[thread_position_in_threadgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint ntg[[threads_per_threadgroup]]) {
+ const int64_t ne = ne00*ne01*ne02;
+ const int64_t gs = ne00*ne01*((ne02 + n_groups - 1) / n_groups);
+
+ int start = tgpig * gs;
+ int end = start + gs;
+
+ start += tpitg;
+
+ if (end >= ne) {
+ end = ne;
+ }
+
+ float tmp = 0.0f; // partial sum for thread in warp
+
+ for (int j = start; j < end; j += ntg) {
+ tmp += src0[j];
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ tmp = simd_sum(tmp);
+ if (ntg > N_SIMDWIDTH) {
+ if (sgitg == 0) {
+ buf[tiisg] = 0.0f;
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ if (tiisg == 0) {
+ buf[sgitg] = tmp;
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ tmp = buf[tiisg];
+ tmp = simd_sum(tmp);
+ }
+
+ const float mean = tmp / gs;
+ tmp = 0.0f;
+
+ for (int j = start; j < end; j += ntg) {
+ float xi = src0[j] - mean;
+ dst[j] = xi;
+ tmp += xi * xi;
+ }
+
+ tmp = simd_sum(tmp);
+ if (ntg > N_SIMDWIDTH) {
+ if (sgitg == 0) {
+ buf[tiisg] = 0.0f;
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ if (tiisg == 0) {
+ buf[sgitg] = tmp;
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ tmp = buf[tiisg];
+ tmp = simd_sum(tmp);
+ }
+
+ const float variance = tmp / gs;
+ const float scale = 1.0f/sqrt(variance + eps);
+ for (int j = start; j < end; j += ntg) {
+ dst[j] *= scale;
+ }
+}
+
+// function for calculate inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i])
+// il indicates where the q4 quants begin (0 or QK4_0/4)
+// we assume that the yl's have been multiplied with the appropriate scale factor
+// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
+inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) {
+ float d = qb_curr->d;
+
+ float2 acc = 0.f;
+
+ device const uint16_t * qs = ((device const uint16_t *)qb_curr + 1 + il/2);
+
+ for (int i = 0; i < 8; i+=2) {
+ acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
+ + yl[i + 1] * (qs[i / 2] & 0x0F00);
+ acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0)
+ + yl[i + 9] * (qs[i / 2] & 0xF000);
+ }
+ return d * (sumy * -8.f + acc[0] + acc[1]);
+}
+
+// function for calculate inner product between half a q4_1 block and 16 floats (yl), sumy is SUM(yl[i])
+// il indicates where the q4 quants begin (0 or QK4_0/4)
+// we assume that the yl's have been multiplied with the appropriate scale factor
+// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
+inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl, int il) {
+ float d = qb_curr->d;
+ float m = qb_curr->m;
+
+ float2 acc = 0.f;
+
+ device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2);
+
+ for (int i = 0; i < 8; i+=2) {
+ acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
+ + yl[i + 1] * (qs[i / 2] & 0x0F00);
+ acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0)
+ + yl[i + 9] * (qs[i / 2] & 0xF000);
+ }
+ return d * (acc[0] + acc[1]) + sumy * m;
+}
+
+// function for calculate inner product between half a q5_0 block and 16 floats (yl), sumy is SUM(yl[i])
+// il indicates where the q5 quants begin (0 or QK5_0/4)
+// we assume that the yl's have been multiplied with the appropriate scale factor
+// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
+inline float block_q_n_dot_y(device const block_q5_0 * qb_curr, float sumy, thread float * yl, int il) {
+ float d = qb_curr->d;
+
+ float2 acc = 0.f;
+
+ device const uint16_t * qs = ((device const uint16_t *)qb_curr + 3 + il/2);
+ const uint32_t qh = *((device const uint32_t *)qb_curr->qh);
+
+ for (int i = 0; i < 8; i+=2) {
+ acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010))
+ + yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000));
+ acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100))
+ + yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000));
+ }
+ return d * (sumy * -16.f + acc[0] + acc[1]);
+}
+
+// function for calculate inner product between half a q5_1 block and 16 floats (yl), sumy is SUM(yl[i])
+// il indicates where the q5 quants begin (0 or QK5_1/4)
+// we assume that the yl's have been multiplied with the appropriate scale factor
+// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
+inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thread float * yl, int il) {
+ float d = qb_curr->d;
+ float m = qb_curr->m;
+
+ float2 acc = 0.f;
+
+ device const uint16_t * qs = ((device const uint16_t *)qb_curr + 4 + il/2);
+ const uint32_t qh = *((device const uint32_t *)qb_curr->qh);
+
+ for (int i = 0; i < 8; i+=2) {
+ acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010))
+ + yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000));
+ acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100))
+ + yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000));
+ }
+ return d * (acc[0] + acc[1]) + sumy * m;
+}
+
+// putting them in the kernel cause a significant performance penalty
+#define N_DST 4 // each SIMD group works on 4 rows
+#define N_SIMDGROUP 2 // number of SIMD groups in a thread group
+//Note: This is a template, but strictly speaking it only applies to
+// quantizations where the block size is 32. It also does not
+// guard against the number of rows not being divisible by
+// N_DST, so this is another explicit assumption of the implementation.
+template<typename block_q_type, int nr, int nsg, int nw>
+void mul_vec_q_n_f32_impl(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ int64_t ne00,
+ int64_t ne01,
+ int64_t ne02,
+ int64_t ne10,
+ int64_t ne12,
+ int64_t ne0,
+ int64_t ne1,
+ uint r2,
+ uint r3,
+ uint3 tgpig, uint tiisg, uint sgitg) {
+ const int nb = ne00/QK4_0;
+
+ const int r0 = tgpig.x;
+ const int r1 = tgpig.y;
+ const int im = tgpig.z;
+
+ const int first_row = (r0 * nsg + sgitg) * nr;
+
+ const uint i12 = im%ne12;
+ const uint i13 = im/ne12;
+
+ const uint offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+
+ device const block_q_type * x = (device const block_q_type *) src0 + offset0;
+ device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
+
+ float yl[16]; // src1 vector cache
+ float sumf[nr] = {0.f};
+
+ const int ix = (tiisg/2);
+ const int il = (tiisg%2)*8;
+
+ device const float * yb = y + ix * QK4_0 + il;
+
+ // each thread in a SIMD group deals with half a block.
+ for (int ib = ix; ib < nb; ib += nw/2) {
+ float sumy = 0;
+ for (int i = 0; i < 8; i += 2) {
+ sumy += yb[i] + yb[i+1];
+ yl[i+0] = yb[i+ 0];
+ yl[i+1] = yb[i+ 1]/256.f;
+
+ sumy += yb[i+16] + yb[i+17];
+ yl[i+8] = yb[i+16]/16.f;
+ yl[i+9] = yb[i+17]/4096.f;
+ }
+
+ for (int row = 0; row < nr; row++) {
+ sumf[row] += block_q_n_dot_y(x+ib+row*nb, sumy, yl, il);
+ }
+
+ yb += QK4_0 * 16;
+ }
+
+ for (int row = 0; row < nr; ++row) {
+ const float tot = simd_sum(sumf[row]);
+ if (tiisg == 0 && first_row + row < ne01) {
+ dst[im*ne0*ne1 + r1*ne0 + first_row + row] = tot;
+ }
+ }
+}
+
+kernel void kernel_mul_mv_q4_0_f32(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
+}
+
+kernel void kernel_mul_mv_q4_1_f32(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
+}
+
+kernel void kernel_mul_mv_q5_0_f32(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
+}
+
+kernel void kernel_mul_mv_q5_1_f32(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
+}
+
+
+#define NB_Q8_0 8
+
+void kernel_mul_mv_q8_0_f32_impl(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne10,
+ constant int64_t & ne12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ const int nr = N_DST;
+ const int nsg = N_SIMDGROUP;
+ const int nw = N_SIMDWIDTH;
+
+ const int nb = ne00/QK8_0;
+ const int r0 = tgpig.x;
+ const int r1 = tgpig.y;
+ const int im = tgpig.z;
+
+ const int first_row = (r0 * nsg + sgitg) * nr;
+
+ const uint i12 = im%ne12;
+ const uint i13 = im/ne12;
+
+ const uint offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+
+ device const block_q8_0 * x = (device const block_q8_0 *) src0 + offset0;
+ device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
+
+ float yl[NB_Q8_0];
+ float sumf[nr]={0.f};
+
+ const int ix = tiisg/4;
+ const int il = tiisg%4;
+
+ device const float * yb = y + ix * QK8_0 + NB_Q8_0*il;
+
+ // each thread in a SIMD group deals with NB_Q8_0 quants at a time
+ for (int ib = ix; ib < nb; ib += nw/4) {
+ for (int i = 0; i < NB_Q8_0; ++i) {
+ yl[i] = yb[i];
+ }
+
+ for (int row = 0; row < nr; row++) {
+ device const int8_t * qs = x[ib+row*nb].qs + NB_Q8_0*il;
+ float sumq = 0.f;
+ for (int iq = 0; iq < NB_Q8_0; ++iq) {
+ sumq += qs[iq] * yl[iq];
+ }
+ sumf[row] += sumq*x[ib+row*nb].d;
+ }
+
+ yb += NB_Q8_0 * nw;
+ }
+
+ for (int row = 0; row < nr; ++row) {
+ const float tot = simd_sum(sumf[row]);
+ if (tiisg == 0 && first_row + row < ne01) {
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
+ }
+ }
+}
+
+[[host_name("kernel_mul_mv_q8_0_f32")]]
+kernel void kernel_mul_mv_q8_0_f32(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
+}
+
+#define N_F32_F32 4
+
+void kernel_mul_mv_f32_f32_impl(
+ device const char * src0,
+ device const char * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]]) {
+
+ const int64_t r0 = tgpig.x;
+ const int64_t rb = tgpig.y*N_F32_F32;
+ const int64_t im = tgpig.z;
+
+ const uint i12 = im%ne12;
+ const uint i13 = im/ne12;
+
+ const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
+
+ device const float * x = (device const float *) (src0 + offset0);
+
+ if (ne00 < 128) {
+ for (int row = 0; row < N_F32_F32; ++row) {
+ int r1 = rb + row;
+ if (r1 >= ne11) {
+ break;
+ }
+
+ device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
+
+ float sumf = 0;
+ for (int i = tiisg; i < ne00; i += 32) {
+ sumf += (float) x[i] * (float) y[i];
+ }
+
+ float all_sum = simd_sum(sumf);
+ if (tiisg == 0) {
+ dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+ }
+ }
+ } else {
+ device const float4 * x4 = (device const float4 *)x;
+ for (int row = 0; row < N_F32_F32; ++row) {
+ int r1 = rb + row;
+ if (r1 >= ne11) {
+ break;
+ }
+
+ device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
+ device const float4 * y4 = (device const float4 *) y;
+
+ float sumf = 0;
+ for (int i = tiisg; i < ne00/4; i += 32) {
+ for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
+ }
+
+ float all_sum = simd_sum(sumf);
+ if (tiisg == 0) {
+ for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i];
+ dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+ }
+ }
+ }
+}
+
+[[host_name("kernel_mul_mv_f32_f32")]]
+kernel void kernel_mul_mv_f32_f32(
+ device const char * src0,
+ device const char * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]]) {
+ kernel_mul_mv_f32_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
+}
+
+#define N_F16_F16 4
+
+kernel void kernel_mul_mv_f16_f16(
+ device const char * src0,
+ device const char * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]]) {
+
+ const int64_t r0 = tgpig.x;
+ const int64_t rb = tgpig.y*N_F16_F16;
+ const int64_t im = tgpig.z;
+
+ const uint i12 = im%ne12;
+ const uint i13 = im/ne12;
+
+ const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
+
+ device const half * x = (device const half *) (src0 + offset0);
+
+ if (ne00 < 128) {
+ for (int row = 0; row < N_F16_F16; ++row) {
+ int r1 = rb + row;
+ if (r1 >= ne11) {
+ break;
+ }
+
+ device const half * y = (device const half *) (src1 + r1*nb11 + im*nb12);
+
+ float sumf = 0;
+ for (int i = tiisg; i < ne00; i += 32) {
+ sumf += (half) x[i] * (half) y[i];
+ }
+
+ float all_sum = simd_sum(sumf);
+ if (tiisg == 0) {
+ dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+ }
+ }
+ } else {
+ device const half4 * x4 = (device const half4 *)x;
+ for (int row = 0; row < N_F16_F16; ++row) {
+ int r1 = rb + row;
+ if (r1 >= ne11) {
+ break;
+ }
+
+ device const half * y = (device const half *) (src1 + r1*nb11 + im*nb12);
+ device const half4 * y4 = (device const half4 *) y;
+
+ float sumf = 0;
+ for (int i = tiisg; i < ne00/4; i += 32) {
+ for (int k = 0; k < 4; ++k) sumf += (half) x4[i][k] * y4[i][k];
+ }
+
+ float all_sum = simd_sum(sumf);
+ if (tiisg == 0) {
+ for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (half) x[i] * y[i];
+ dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+ }
+ }
+ }
+}
+
+void kernel_mul_mv_f16_f32_1row_impl(
+ device const char * src0,
+ device const char * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]]) {
+
+ const int64_t r0 = tgpig.x;
+ const int64_t r1 = tgpig.y;
+ const int64_t im = tgpig.z;
+
+ const uint i12 = im%ne12;
+ const uint i13 = im/ne12;
+
+ const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
+
+ device const half * x = (device const half *) (src0 + offset0);
+ device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
+
+ float sumf = 0;
+ if (ne00 < 128) {
+ for (int i = tiisg; i < ne00; i += 32) {
+ sumf += (float) x[i] * (float) y[i];
+ }
+ float all_sum = simd_sum(sumf);
+ if (tiisg == 0) {
+ dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+ }
+ } else {
+ device const half4 * x4 = (device const half4 *) x;
+ device const float4 * y4 = (device const float4 *) y;
+ for (int i = tiisg; i < ne00/4; i += 32) {
+ for (int k = 0; k < 4; ++k) sumf += (float)x4[i][k] * y4[i][k];
+ }
+ float all_sum = simd_sum(sumf);
+ if (tiisg == 0) {
+ for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i];
+ dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+ }
+ }
+}
+
+[[host_name("kernel_mul_mv_f16_f32_1row")]]
+kernel void kernel_mul_mv_f16_f32_1row(
+ device const char * src0,
+ device const char * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]]) {
+ kernel_mul_mv_f16_f32_1row_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
+}
+
+#define N_F16_F32 4
+
+void kernel_mul_mv_f16_f32_impl(
+ device const char * src0,
+ device const char * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]]) {
+
+ const int64_t r0 = tgpig.x;
+ const int64_t rb = tgpig.y*N_F16_F32;
+ const int64_t im = tgpig.z;
+
+ const uint i12 = im%ne12;
+ const uint i13 = im/ne12;
+
+ const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
+
+ device const half * x = (device const half *) (src0 + offset0);
+
+ if (ne00 < 128) {
+ for (int row = 0; row < N_F16_F32; ++row) {
+ int r1 = rb + row;
+ if (r1 >= ne11) {
+ break;
+ }
+
+ device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
+
+ float sumf = 0;
+ for (int i = tiisg; i < ne00; i += 32) {
+ sumf += (float) x[i] * (float) y[i];
+ }
+
+ float all_sum = simd_sum(sumf);
+ if (tiisg == 0) {
+ dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+ }
+ }
+ } else {
+ device const half4 * x4 = (device const half4 *)x;
+ for (int row = 0; row < N_F16_F32; ++row) {
+ int r1 = rb + row;
+ if (r1 >= ne11) {
+ break;
+ }
+
+ device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
+ device const float4 * y4 = (device const float4 *) y;
+
+ float sumf = 0;
+ for (int i = tiisg; i < ne00/4; i += 32) {
+ for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
+ }
+
+ float all_sum = simd_sum(sumf);
+ if (tiisg == 0) {
+ for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i];
+ dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+ }
+ }
+ }
+}
+
+[[host_name("kernel_mul_mv_f16_f32")]]
+kernel void kernel_mul_mv_f16_f32(
+ device const char * src0,
+ device const char * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]]) {
+ kernel_mul_mv_f16_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
+}
+
+// Assumes row size (ne00) is a multiple of 4
+kernel void kernel_mul_mv_f16_f32_l4(
+ device const char * src0,
+ device const char * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]]) {
+
+ const int nrows = ne11;
+ const int64_t r0 = tgpig.x;
+ const int64_t im = tgpig.z;
+
+ const uint i12 = im%ne12;
+ const uint i13 = im/ne12;
+
+ const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
+
+ device const half4 * x4 = (device const half4 *) (src0 + offset0);
+
+ for (int r1 = 0; r1 < nrows; ++r1) {
+ device const float4 * y4 = (device const float4 *) (src1 + r1*nb11 + im*nb12);
+
+ float sumf = 0;
+ for (int i = tiisg; i < ne00/4; i += 32) {
+ for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
+ }
+
+ float all_sum = simd_sum(sumf);
+ if (tiisg == 0) {
+ dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+ }
+ }
+}
+
+kernel void kernel_alibi_f32(
+ device const float * src0,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne03,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant uint64_t & nb03,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & ne2,
+ constant int64_t & ne3,
+ constant uint64_t & nb0,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ constant uint64_t & nb3,
+ constant float & m0,
+ constant float & m1,
+ constant int & n_heads_log2_floor,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) {
+ const int64_t i03 = tgpig[2];
+ const int64_t i02 = tgpig[1];
+ const int64_t i01 = tgpig[0];
+
+ const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+ const int64_t i3 = n / (ne2*ne1*ne0);
+ const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+ const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+ //const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
+
+ const int64_t k = i3*ne3 + i2;
+
+ float m_k;
+ if (k < n_heads_log2_floor) {
+ m_k = pow(m0, k + 1);
+ } else {
+ m_k = pow(m1, 2 * (k - n_heads_log2_floor) + 1);
+ }
+
+ device char * dst_row = (device char *) dst + i3*nb3 + i2*nb2 + i1*nb1;
+ device const char * src_row = (device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01;
+ for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
+ const float src_v = *(device float *)(src_row + i00*nb00);
+ device float * dst_v = (device float *)(dst_row + i00*nb0);
+ *dst_v = i00 * m_k + src_v;
+ }
+}
+
+static float rope_yarn_ramp(const float low, const float high, const int i0) {
+ const float y = (i0 / 2 - low) / max(0.001f, high - low);
+ return 1.0f - min(1.0f, max(0.0f, y));
+}
+
+// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
+// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
+static void rope_yarn(
+ float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale,
+ thread float * cos_theta, thread float * sin_theta
+) {
+ // Get n-d rotational scaling corrected for extrapolation
+ float theta_interp = freq_scale * theta_extrap;
+ float theta = theta_interp;
+ if (ext_factor != 0.0f) {
+ float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
+ theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
+
+ // Get n-d magnitude scaling corrected for interpolation
+ mscale *= 1.0f + 0.1f * log(1.0f / freq_scale);
+ }
+ *cos_theta = cos(theta) * mscale;
+ *sin_theta = sin(theta) * mscale;
+}
+
+// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
+// `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
+static float rope_yarn_corr_factor(int n_dims, int n_orig_ctx, float n_rot, float base) {
+ return n_dims * log(n_orig_ctx / (n_rot * 2 * M_PI_F)) / (2 * log(base));
+}
+
+static void rope_yarn_corr_dims(
+ int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]
+) {
+ // start and end correction dims
+ dims[0] = max(0.0f, floor(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_fast, freq_base)));
+ dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_slow, freq_base)));
+}
+
+typedef void (rope_t)(
+ device const void * src0,
+ device const int32_t * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne03,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant uint64_t & nb03,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & ne2,
+ constant int64_t & ne3,
+ constant uint64_t & nb0,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ constant uint64_t & nb3,
+ constant int & n_past,
+ constant int & n_dims,
+ constant int & mode,
+ constant int & n_orig_ctx,
+ constant float & freq_base,
+ constant float & freq_scale,
+ constant float & ext_factor,
+ constant float & attn_factor,
+ constant float & beta_fast,
+ constant float & beta_slow,
+ uint tiitg[[thread_index_in_threadgroup]],
+ uint3 tptg[[threads_per_threadgroup]],
+ uint3 tgpig[[threadgroup_position_in_grid]]);
+
+template<typename T>
+kernel void kernel_rope(
+ device const void * src0,
+ device const int32_t * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne03,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant uint64_t & nb03,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & ne2,
+ constant int64_t & ne3,
+ constant uint64_t & nb0,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ constant uint64_t & nb3,
+ constant int & n_past,
+ constant int & n_dims,
+ constant int & mode,
+ constant int & n_orig_ctx,
+ constant float & freq_base,
+ constant float & freq_scale,
+ constant float & ext_factor,
+ constant float & attn_factor,
+ constant float & beta_fast,
+ constant float & beta_slow,
+ uint tiitg[[thread_index_in_threadgroup]],
+ uint3 tptg[[threads_per_threadgroup]],
+ uint3 tgpig[[threadgroup_position_in_grid]]) {
+ const int64_t i3 = tgpig[2];
+ const int64_t i2 = tgpig[1];
+ const int64_t i1 = tgpig[0];
+
+ const bool is_neox = mode & 2;
+
+ float corr_dims[2];
+ rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
+
+ device const int32_t * pos = src1;
+
+ const int64_t p = pos[i2];
+
+ const float theta_0 = (float)p;
+ const float inv_ndims = -1.f/n_dims;
+
+ if (!is_neox) {
+ for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
+
+ const float theta = theta_0 * pow(freq_base, inv_ndims*i0);
+ float cos_theta, sin_theta;
+ rope_yarn(theta, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
+
+ device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+ device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+ const T x0 = src[0];
+ const T x1 = src[1];
+
+ dst_data[0] = x0*cos_theta - x1*sin_theta;
+ dst_data[1] = x0*sin_theta + x1*cos_theta;
+ }
+ } else {
+ for (int64_t ic = 2*tiitg; ic < ne0; ic += 2*tptg.x) {
+ if (ic < n_dims) {
+ const int64_t ib = 0;
+
+ // simplified from `(ib * n_dims + ic) * inv_ndims`
+ const float cur_rot = inv_ndims*ic - ib;
+
+ const float theta = theta_0 * pow(freq_base, cur_rot);
+ float cos_theta, sin_theta;
+ rope_yarn(theta, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
+
+ const int64_t i0 = ib*n_dims + ic/2;
+
+ device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+ device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+ const float x0 = src[0];
+ const float x1 = src[n_dims/2];
+
+ dst_data[0] = x0*cos_theta - x1*sin_theta;
+ dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
+ } else {
+ const int64_t i0 = ic;
+
+ device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+ device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+ dst_data[0] = src[0];
+ dst_data[1] = src[1];
+ }
+ }
+ }
+}
+
+template [[host_name("kernel_rope_f32")]] kernel rope_t kernel_rope<float>;
+template [[host_name("kernel_rope_f16")]] kernel rope_t kernel_rope<half>;
+
+kernel void kernel_im2col_f16(
+ device const float * x,
+ device half * dst,
+ constant int32_t & ofs0,
+ constant int32_t & ofs1,
+ constant int32_t & IW,
+ constant int32_t & IH,
+ constant int32_t & CHW,
+ constant int32_t & s0,
+ constant int32_t & s1,
+ constant int32_t & p0,
+ constant int32_t & p1,
+ constant int32_t & d0,
+ constant int32_t & d1,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tgpg[[threadgroups_per_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) {
+ const int32_t iiw = tgpig[2] * s0 + tpitg[2] * d0 - p0;
+ const int32_t iih = tgpig[1] * s1 + tpitg[1] * d1 - p1;
+
+ const int32_t offset_dst =
+ (tpitg[0] * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * CHW +
+ (tgpig[0] * (ntg[1] * ntg[2]) + tpitg[1] * ntg[2] + tpitg[2]);
+
+ if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
+ dst[offset_dst] = 0.0f;
+ } else {
+ const int32_t offset_src = tpitg[0] * ofs0 + tgpig[0] * ofs1;
+ dst[offset_dst] = x[offset_src + iih * IW + iiw];
+ }
+}
+
+kernel void kernel_upscale_f32(
+ device const char * src0,
+ device char * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne03,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant uint64_t & nb03,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & ne2,
+ constant int64_t & ne3,
+ constant uint64_t & nb0,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ constant uint64_t & nb3,
+ constant int32_t & sf,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) {
+
+ const int64_t i3 = tgpig.z;
+ const int64_t i2 = tgpig.y;
+ const int64_t i1 = tgpig.x;
+
+ const int64_t i03 = i3;
+ const int64_t i02 = i2;
+ const int64_t i01 = i1/sf;
+
+ device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
+ device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1);
+
+ for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+ dst_ptr[i0] = src0_ptr[i0/sf];
+ }
+}
+
+kernel void kernel_pad_f32(
+ device const char * src0,
+ device char * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne03,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant uint64_t & nb03,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & ne2,
+ constant int64_t & ne3,
+ constant uint64_t & nb0,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ constant uint64_t & nb3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) {
+
+ const int64_t i3 = tgpig.z;
+ const int64_t i2 = tgpig.y;
+ const int64_t i1 = tgpig.x;
+
+ const int64_t i03 = i3;
+ const int64_t i02 = i2;
+ const int64_t i01 = i1;
+
+ device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
+ device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1);
+
+ if (i1 < ne01 && i2 < ne02 && i3 < ne03) {
+ for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+ if (i0 < ne00) {
+ dst_ptr[i0] = src0_ptr[i0];
+ } else {
+ dst_ptr[i0] = 0.0f;
+ }
+ }
+
+ return;
+ }
+
+ for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+ dst_ptr[i0] = 0.0f;
+ }
+}
+
+// bitonic sort implementation following the CUDA kernels as reference
+typedef void (argsort_t)(
+ device const float * x,
+ device int32_t * dst,
+ constant int64_t & ncols,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]]);
+
+template<ggml_sort_order order>
+kernel void kernel_argsort_f32_i32(
+ device const float * x,
+ device int32_t * dst,
+ constant int64_t & ncols,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]]) {
+ // bitonic sort
+ int col = tpitg[0];
+ int row = tgpig[1];
+
+ if (col >= ncols) return;
+
+ device const float * x_row = x + row * ncols;
+ device int32_t * dst_row = dst + row * ncols;
+
+ // initialize indices
+ if (col < ncols) {
+ dst_row[col] = col;
+ }
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ for (int k = 2; k <= ncols; k *= 2) {
+ for (int j = k / 2; j > 0; j /= 2) {
+ int ixj = col ^ j;
+ if (ixj > col) {
+ if ((col & k) == 0) {
+ if (order == GGML_SORT_ASC ? x_row[dst_row[col]] > x_row[dst_row[ixj]] : x_row[dst_row[col]] < x_row[dst_row[ixj]]) {
+ SWAP(dst_row[col], dst_row[ixj]);
+ }
+ } else {
+ if (order == GGML_SORT_ASC ? x_row[dst_row[col]] < x_row[dst_row[ixj]] : x_row[dst_row[col]] > x_row[dst_row[ixj]]) {
+ SWAP(dst_row[col], dst_row[ixj]);
+ }
+ }
+ }
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ }
+ }
+}
+
+template [[host_name("kernel_argsort_f32_i32_asc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ASC>;
+template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_DESC>;
+
+kernel void kernel_leaky_relu_f32(
+ device const float * src0,
+ device float * dst,
+ constant float & slope,
+ uint tpig[[thread_position_in_grid]]) {
+ dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * slope;
+}
+
+kernel void kernel_cpy_f16_f16(
+ device const half * src0,
+ device half * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne03,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant uint64_t & nb03,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & ne2,
+ constant int64_t & ne3,
+ constant uint64_t & nb0,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ constant uint64_t & nb3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) {
+ const int64_t i03 = tgpig[2];
+ const int64_t i02 = tgpig[1];
+ const int64_t i01 = tgpig[0];
+
+ const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+ const int64_t i3 = n / (ne2*ne1*ne0);
+ const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+ const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+ const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
+
+ device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+ for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
+ device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+ dst_data[i00] = src[0];
+ }
+}
+
+kernel void kernel_cpy_f16_f32(
+ device const half * src0,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne03,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant uint64_t & nb03,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & ne2,
+ constant int64_t & ne3,
+ constant uint64_t & nb0,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ constant uint64_t & nb3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) {
+ const int64_t i03 = tgpig[2];
+ const int64_t i02 = tgpig[1];
+ const int64_t i01 = tgpig[0];
+
+ const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+ const int64_t i3 = n / (ne2*ne1*ne0);
+ const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+ const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+ const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
+
+ device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+ for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
+ device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+ dst_data[i00] = src[0];
+ }
+}
+
+kernel void kernel_cpy_f32_f16(
+ device const float * src0,
+ device half * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne03,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant uint64_t & nb03,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & ne2,
+ constant int64_t & ne3,
+ constant uint64_t & nb0,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ constant uint64_t & nb3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) {
+ const int64_t i03 = tgpig[2];
+ const int64_t i02 = tgpig[1];
+ const int64_t i01 = tgpig[0];
+
+ const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+ const int64_t i3 = n / (ne2*ne1*ne0);
+ const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+ const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+ const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
+
+ device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+ for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
+ device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+
+ dst_data[i00] = src[0];
+ }
+}
+
+kernel void kernel_cpy_f32_f32(
+ device const float * src0,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne03,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant uint64_t & nb03,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & ne2,
+ constant int64_t & ne3,
+ constant uint64_t & nb0,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ constant uint64_t & nb3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) {
+ const int64_t i03 = tgpig[2];
+ const int64_t i02 = tgpig[1];
+ const int64_t i01 = tgpig[0];
+
+ const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+ const int64_t i3 = n / (ne2*ne1*ne0);
+ const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+ const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+ const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
+
+ device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+ for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
+ device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+
+ dst_data[i00] = src[0];
+ }
+}
+
+kernel void kernel_cpy_f32_q8_0(
+ device const float * src0,
+ device void * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne03,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant uint64_t & nb03,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & ne2,
+ constant int64_t & ne3,
+ constant uint64_t & nb0,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ constant uint64_t & nb3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) {
+ const int64_t i03 = tgpig[2];
+ const int64_t i02 = tgpig[1];
+ const int64_t i01 = tgpig[0];
+
+ const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+ const int64_t i3 = n / (ne2*ne1*ne0);
+ const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+ const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+ const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK8_0;
+
+ device block_q8_0 * dst_data = (device block_q8_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+ for (int64_t i00 = tpitg.x*QK8_0; i00 < ne00; i00 += ntg.x*QK8_0) {
+ device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+
+ float amax = 0.0f; // absolute max
+
+ for (int j = 0; j < QK8_0; j++) {
+ const float v = src[j];
+ amax = MAX(amax, fabs(v));
+ }
+
+ const float d = amax / ((1 << 7) - 1);
+ const float id = d ? 1.0f/d : 0.0f;
+
+ dst_data[i00/QK8_0].d = d;
+
+ for (int j = 0; j < QK8_0; ++j) {
+ const float x0 = src[j]*id;
+
+ dst_data[i00/QK8_0].qs[j] = round(x0);
+ }
+ }
+}
+
+kernel void kernel_cpy_f32_q4_0(
+ device const float * src0,
+ device void * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne03,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant uint64_t & nb03,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & ne2,
+ constant int64_t & ne3,
+ constant uint64_t & nb0,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ constant uint64_t & nb3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) {
+ const int64_t i03 = tgpig[2];
+ const int64_t i02 = tgpig[1];
+ const int64_t i01 = tgpig[0];
+
+ const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+ const int64_t i3 = n / (ne2*ne1*ne0);
+ const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+ const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+ const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK4_0;
+
+ device block_q4_0 * dst_data = (device block_q4_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+ for (int64_t i00 = tpitg.x*QK4_0; i00 < ne00; i00 += ntg.x*QK4_0) {
+ device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+
+ float amax = 0.0f; // absolute max
+ float max = 0.0f;
+
+ for (int j = 0; j < QK4_0; j++) {
+ const float v = src[j];
+ if (amax < fabs(v)) {
+ amax = fabs(v);
+ max = v;
+ }
+ }
+
+ const float d = max / -8;
+ const float id = d ? 1.0f/d : 0.0f;
+
+ dst_data[i00/QK4_0].d = d;
+
+ for (int j = 0; j < QK4_0/2; ++j) {
+ const float x0 = src[0 + j]*id;
+ const float x1 = src[QK4_0/2 + j]*id;
+
+ const uint8_t xi0 = MIN(15, (int8_t)(x0 + 8.5f));
+ const uint8_t xi1 = MIN(15, (int8_t)(x1 + 8.5f));
+
+ dst_data[i00/QK4_0].qs[j] = xi0;
+ dst_data[i00/QK4_0].qs[j] |= xi1 << 4;
+ }
+ }
+}
+
+kernel void kernel_cpy_f32_q4_1(
+ device const float * src0,
+ device void * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne03,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant uint64_t & nb03,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & ne2,
+ constant int64_t & ne3,
+ constant uint64_t & nb0,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ constant uint64_t & nb3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) {
+ const int64_t i03 = tgpig[2];
+ const int64_t i02 = tgpig[1];
+ const int64_t i01 = tgpig[0];
+
+ const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+ const int64_t i3 = n / (ne2*ne1*ne0);
+ const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+ const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+ const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK4_1;
+
+ device block_q4_1 * dst_data = (device block_q4_1 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+ for (int64_t i00 = tpitg.x*QK4_1; i00 < ne00; i00 += ntg.x*QK4_1) {
+ device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+
+ float min = FLT_MAX;
+ float max = -FLT_MAX;
+
+ for (int j = 0; j < QK4_1; j++) {
+ const float v = src[j];
+ if (min > v) min = v;
+ if (max < v) max = v;
+ }
+
+ const float d = (max - min) / ((1 << 4) - 1);
+ const float id = d ? 1.0f/d : 0.0f;
+
+ dst_data[i00/QK4_1].d = d;
+ dst_data[i00/QK4_1].m = min;
+
+ for (int j = 0; j < QK4_1/2; ++j) {
+ const float x0 = (src[0 + j] - min)*id;
+ const float x1 = (src[QK4_1/2 + j] - min)*id;
+
+ const uint8_t xi0 = MIN(15, (int8_t)(x0 + 0.5f));
+ const uint8_t xi1 = MIN(15, (int8_t)(x1 + 0.5f));
+
+ dst_data[i00/QK4_1].qs[j] = xi0;
+ dst_data[i00/QK4_1].qs[j] |= xi1 << 4;
+ }
+ }
+}
+
+kernel void kernel_concat(
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne03,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant uint64_t & nb03,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant uint64_t & nb13,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & ne2,
+ constant int64_t & ne3,
+ constant uint64_t & nb0,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ constant uint64_t & nb3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) {
+
+ const int64_t i03 = tgpig.z;
+ const int64_t i02 = tgpig.y;
+ const int64_t i01 = tgpig.x;
+
+ const int64_t i13 = i03 % ne13;
+ const int64_t i12 = i02 % ne12;
+ const int64_t i11 = i01 % ne11;
+
+ device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + tpitg.x*nb00;
+ device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10;
+ device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + tpitg.x*nb0;
+
+ for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+ if (i02 < ne02) {
+ ((device float *)dst_ptr)[0] = ((device float *)src0_ptr)[0];
+ src0_ptr += ntg.x*nb00;
+ } else {
+ ((device float *)dst_ptr)[0] = ((device float *)src1_ptr)[0];
+ src1_ptr += ntg.x*nb10;
+ }
+ dst_ptr += ntg.x*nb0;
+ }
+}
+
+//============================================ k-quants ======================================================
+
+#ifndef QK_K
+#define QK_K 256
+#else
+static_assert(QK_K == 256 || QK_K == 64, "QK_K must be 256 or 64");
+#endif
+
+#if QK_K == 256
+#define K_SCALE_SIZE 12
+#else
+#define K_SCALE_SIZE 4
+#endif
+
+typedef struct {
+ uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits
+ uint8_t qs[QK_K/4]; // quants
+ half d; // super-block scale for quantized scales
+ half dmin; // super-block scale for quantized mins
+} block_q2_K;
+// 84 bytes / block
+
+typedef struct {
+ uint8_t hmask[QK_K/8]; // quants - high bit
+ uint8_t qs[QK_K/4]; // quants - low 2 bits
+#if QK_K == 64
+ uint8_t scales[2];
+#else
+ uint8_t scales[K_SCALE_SIZE]; // scales, quantized with 6 bits
+#endif
+ half d; // super-block scale
+} block_q3_K;
+
+#if QK_K == 64
+typedef struct {
+ half d[2]; // super-block scales/mins
+ uint8_t scales[2];
+ uint8_t qs[QK_K/2]; // 4-bit quants
+} block_q4_K;
+#else
+typedef struct {
+ half d; // super-block scale for quantized scales
+ half dmin; // super-block scale for quantized mins
+ uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
+ uint8_t qs[QK_K/2]; // 4--bit quants
+} block_q4_K;
+#endif
+
+#if QK_K == 64
+typedef struct {
+ half d; // super-block scales/mins
+ int8_t scales[QK_K/16]; // 8-bit block scales
+ uint8_t qh[QK_K/8]; // quants, high bit
+ uint8_t qs[QK_K/2]; // quants, low 4 bits
+} block_q5_K;
+#else
+typedef struct {
+ half d; // super-block scale for quantized scales
+ half dmin; // super-block scale for quantized mins
+ uint8_t scales[3*QK_K/64]; // scales and mins, quantized with 6 bits
+ uint8_t qh[QK_K/8]; // quants, high bit
+ uint8_t qs[QK_K/2]; // quants, low 4 bits
+} block_q5_K;
+// 176 bytes / block
+#endif
+
+typedef struct {
+ uint8_t ql[QK_K/2]; // quants, lower 4 bits
+ uint8_t qh[QK_K/4]; // quants, upper 2 bits
+ int8_t scales[QK_K/16]; // scales, quantized with 8 bits
+ half d; // super-block scale
+} block_q6_K;
+// 210 bytes / block
+
+//====================================== dot products =========================
+
+void kernel_mul_mv_q2_K_f32_impl(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne10,
+ constant int64_t & ne12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+
+ const int nb = ne00/QK_K;
+ const int r0 = tgpig.x;
+ const int r1 = tgpig.y;
+ const int im = tgpig.z;
+
+ const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+ const int ib_row = first_row * nb;
+
+ const uint i12 = im%ne12;
+ const uint i13 = im/ne12;
+
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+
+ device const block_q2_K * x = (device const block_q2_K *) src0 + ib_row + offset0;
+ device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
+
+ float yl[32];
+ float sumf[N_DST]={0.f}, all_sum;
+
+ const int step = sizeof(block_q2_K) * nb;
+
+#if QK_K == 256
+ const int ix = tiisg/8; // 0...3
+ const int it = tiisg%8; // 0...7
+ const int iq = it/4; // 0 or 1
+ const int ir = it%4; // 0...3
+ const int is = (8*ir)/16;// 0 or 1
+
+ device const float * y4 = y + ix * QK_K + 128 * iq + 8 * ir;
+
+ for (int ib = ix; ib < nb; ib += 4) {
+
+ float4 sumy = {0.f, 0.f, 0.f, 0.f};
+ for (int i = 0; i < 8; ++i) {
+ yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
+ yl[i+ 8] = y4[i+32]; sumy[1] += yl[i+ 8];
+ yl[i+16] = y4[i+64]; sumy[2] += yl[i+16];
+ yl[i+24] = y4[i+96]; sumy[3] += yl[i+24];
+ }
+
+ device const uint8_t * sc = (device const uint8_t *)x[ib].scales + 8*iq + is;
+ device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir;
+ device const half * dh = &x[ib].d;
+
+ for (int row = 0; row < N_DST; row++) {
+
+ float4 acc1 = {0.f, 0.f, 0.f, 0.f};
+ float4 acc2 = {0.f, 0.f, 0.f, 0.f};
+ for (int i = 0; i < 8; i += 2) {
+ acc1[0] += yl[i+ 0] * (qs[i/2] & 0x0003);
+ acc2[0] += yl[i+ 1] * (qs[i/2] & 0x0300);
+ acc1[1] += yl[i+ 8] * (qs[i/2] & 0x000c);
+ acc2[1] += yl[i+ 9] * (qs[i/2] & 0x0c00);
+ acc1[2] += yl[i+16] * (qs[i/2] & 0x0030);
+ acc2[2] += yl[i+17] * (qs[i/2] & 0x3000);
+ acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0);
+ acc2[3] += yl[i+25] * (qs[i/2] & 0xc000);
+ }
+ float dall = dh[0];
+ float dmin = dh[1] * 1.f/16.f;
+ sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * (sc[0] & 0xF) * 1.f/ 1.f +
+ (acc1[1] + 1.f/256.f * acc2[1]) * (sc[2] & 0xF) * 1.f/ 4.f +
+ (acc1[2] + 1.f/256.f * acc2[2]) * (sc[4] & 0xF) * 1.f/16.f +
+ (acc1[3] + 1.f/256.f * acc2[3]) * (sc[6] & 0xF) * 1.f/64.f) -
+ dmin * (sumy[0] * (sc[0] & 0xF0) + sumy[1] * (sc[2] & 0xF0) + sumy[2] * (sc[4] & 0xF0) + sumy[3] * (sc[6] & 0xF0));
+
+ qs += step/2;
+ sc += step;
+ dh += step/2;
+ }
+
+ y4 += 4 * QK_K;
+ }
+#else
+ const int ix = tiisg/2; // 0...15
+ const int it = tiisg%2; // 0...1
+
+ device const float * y4 = y + ix * QK_K + 8 * it;
+
+ for (int ib = ix; ib < nb; ib += 16) {
+
+ float4 sumy = {0.f, 0.f, 0.f, 0.f};
+ for (int i = 0; i < 8; ++i) {
+ yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
+ yl[i+ 8] = y4[i+16]; sumy[1] += yl[i+ 8];
+ yl[i+16] = y4[i+32]; sumy[2] += yl[i+16];
+ yl[i+24] = y4[i+48]; sumy[3] += yl[i+24];
+ }
+
+ device const uint8_t * sc = (device const uint8_t *)x[ib].scales;
+ device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 4 * it;
+ device const half * dh = &x[ib].d;
+
+ for (int row = 0; row < N_DST; row++) {
+
+ float4 acc1 = {0.f, 0.f, 0.f, 0.f};
+ float4 acc2 = {0.f, 0.f, 0.f, 0.f};
+ for (int i = 0; i < 8; i += 2) {
+ acc1[0] += yl[i+ 0] * (qs[i/2] & 0x0003);
+ acc2[0] += yl[i+ 1] * (qs[i/2] & 0x0300);
+ acc1[1] += yl[i+ 8] * (qs[i/2] & 0x000c);
+ acc2[1] += yl[i+ 9] * (qs[i/2] & 0x0c00);
+ acc1[2] += yl[i+16] * (qs[i/2] & 0x0030);
+ acc2[2] += yl[i+17] * (qs[i/2] & 0x3000);
+ acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0);
+ acc2[3] += yl[i+25] * (qs[i/2] & 0xc000);
+ }
+
+ float dall = dh[0];
+ float dmin = dh[1];
+ sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * (sc[0] & 0xF) * 1.f/ 1.f +
+ (acc1[1] + 1.f/256.f * acc2[1]) * (sc[1] & 0xF) * 1.f/ 4.f +
+ (acc1[2] + 1.f/256.f * acc2[2]) * (sc[2] & 0xF) * 1.f/16.f +
+ (acc1[3] + 1.f/256.f * acc2[3]) * (sc[3] & 0xF) * 1.f/64.f) -
+ dmin * (sumy[0] * (sc[0] >> 4) + sumy[1] * (sc[1] >> 4) + sumy[2] * (sc[2] >> 4) + sumy[3] * (sc[3] >> 4));
+
+ qs += step/2;
+ sc += step;
+ dh += step/2;
+ }
+
+ y4 += 16 * QK_K;
+ }
+#endif
+
+ for (int row = 0; row < N_DST; ++row) {
+ all_sum = simd_sum(sumf[row]);
+ if (tiisg == 0) {
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
+ }
+ }
+}
+
+[[host_name("kernel_mul_mv_q2_K_f32")]]
+kernel void kernel_mul_mv_q2_K_f32(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+
+ kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
+}
+
+#if QK_K == 256
+void kernel_mul_mv_q3_K_f32_impl(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne10,
+ constant int64_t & ne12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+
+ const int nb = ne00/QK_K;
+
+ const int64_t r0 = tgpig.x;
+ const int64_t r1 = tgpig.y;
+ const int64_t im = tgpig.z;
+
+ const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
+
+ const uint i12 = im%ne12;
+ const uint i13 = im/ne12;
+
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+
+ device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb + offset0;
+ device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
+
+ float yl[32];
+
+ //const uint16_t kmask1 = 0x3030;
+ //const uint16_t kmask2 = 0x0f0f;
+
+ const int tid = tiisg/4;
+ const int ix = tiisg%4;
+ const int ip = tid/4; // 0 or 1
+ const int il = 2*((tid%4)/2); // 0 or 2
+ const int ir = tid%2;
+ const int n = 8;
+ const int l0 = n*ir;
+
+ // One would think that the Metal compiler would figure out that ip and il can only have
+ // 4 possible states, and optimize accordingly. Well, no. It needs help, and we do it
+ // with these two tales.
+ //
+ // Possible masks for the high bit
+ const ushort4 mm[4] = {{0x0001, 0x0100, 0x0002, 0x0200}, // ip = 0, il = 0
+ {0x0004, 0x0400, 0x0008, 0x0800}, // ip = 0, il = 2
+ {0x0010, 0x1000, 0x0020, 0x2000}, // ip = 1, il = 0
+ {0x0040, 0x4000, 0x0080, 0x8000}}; // ip = 1, il = 2
+
+ // Possible masks for the low 2 bits
+ const int4 qm[2] = {{0x0003, 0x0300, 0x000c, 0x0c00}, {0x0030, 0x3000, 0x00c0, 0xc000}};
+
+ const ushort4 hm = mm[2*ip + il/2];
+
+ const int shift = 2*il;
+ const float v1 = il == 0 ? 4.f : 64.f;
+ const float v2 = 4.f * v1;
+
+ const uint16_t s_shift1 = 4*ip;
+ const uint16_t s_shift2 = s_shift1 + il;
+
+ const int q_offset = 32*ip + l0;
+ const int y_offset = 128*ip + 32*il + l0;
+
+ const int step = sizeof(block_q3_K) * nb / 2;
+
+ device const float * y1 = yy + ix*QK_K + y_offset;
+
+ uint32_t scales32, aux32;
+ thread uint16_t * scales16 = (thread uint16_t *)&scales32;
+ thread const int8_t * scales = (thread const int8_t *)&scales32;
+
+ float sumf1[2] = {0.f};
+ float sumf2[2] = {0.f};
+ for (int i = ix; i < nb; i += 4) {
+
+ for (int l = 0; l < 8; ++l) {
+ yl[l+ 0] = y1[l+ 0];
+ yl[l+ 8] = y1[l+16];
+ yl[l+16] = y1[l+32];
+ yl[l+24] = y1[l+48];
+ }
+
+ device const uint16_t * q = (device const uint16_t *)(x[i].qs + q_offset);
+ device const uint16_t * h = (device const uint16_t *)(x[i].hmask + l0);
+ device const uint16_t * a = (device const uint16_t *)(x[i].scales);
+ device const half * dh = &x[i].d;
+
+ for (int row = 0; row < 2; ++row) {
+
+ const float d_all = (float)dh[0];
+
+ scales16[0] = a[4];
+ scales16[1] = a[5];
+ aux32 = ((scales32 >> s_shift2) << 4) & 0x30303030;
+ scales16[0] = a[il+0];
+ scales16[1] = a[il+1];
+ scales32 = ((scales32 >> s_shift1) & 0x0f0f0f0f) | aux32;
+
+ float s1 = 0, s2 = 0, s3 = 0, s4 = 0, s5 = 0, s6 = 0;
+ for (int l = 0; l < n; l += 2) {
+ const int32_t qs = q[l/2];
+ s1 += yl[l+0] * (qs & qm[il/2][0]);
+ s2 += yl[l+1] * (qs & qm[il/2][1]);
+ s3 += ((h[l/2] & hm[0]) ? 0.f : yl[l+0]) + ((h[l/2] & hm[1]) ? 0.f : yl[l+1]);
+ s4 += yl[l+16] * (qs & qm[il/2][2]);
+ s5 += yl[l+17] * (qs & qm[il/2][3]);
+ s6 += ((h[l/2] & hm[2]) ? 0.f : yl[l+16]) + ((h[l/2] & hm[3]) ? 0.f : yl[l+17]);
+ }
+ float d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1);
+ float d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2);
+ sumf1[row] += d1 * (scales[0] - 32);
+ sumf2[row] += d2 * (scales[2] - 32);
+
+ s1 = s2 = s3 = s4 = s5 = s6 = 0;
+ for (int l = 0; l < n; l += 2) {
+ const int32_t qs = q[l/2+8];
+ s1 += yl[l+8] * (qs & qm[il/2][0]);
+ s2 += yl[l+9] * (qs & qm[il/2][1]);
+ s3 += ((h[l/2+8] & hm[0]) ? 0.f : yl[l+8]) + ((h[l/2+8] & hm[1]) ? 0.f : yl[l+9]);
+ s4 += yl[l+24] * (qs & qm[il/2][2]);
+ s5 += yl[l+25] * (qs & qm[il/2][3]);
+ s6 += ((h[l/2+8] & hm[2]) ? 0.f : yl[l+24]) + ((h[l/2+8] & hm[3]) ? 0.f : yl[l+25]);
+ }
+ d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1);
+ d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2);
+ sumf1[row] += d1 * (scales[1] - 32);
+ sumf2[row] += d2 * (scales[3] - 32);
+
+ q += step;
+ h += step;
+ a += step;
+ dh += step;
+
+ }
+
+ y1 += 4 * QK_K;
+
+ }
+
+ for (int row = 0; row < 2; ++row) {
+ const float sumf = (sumf1[row] + 0.25f * sumf2[row]) / (1 << shift);
+ sumf1[row] = simd_sum(sumf);
+ }
+ if (tiisg == 0) {
+ for (int row = 0; row < 2; ++row) {
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = sumf1[row];
+ }
+ }
+}
+#else
+void kernel_mul_mv_q3_K_f32_impl(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne10,
+ constant int64_t & ne12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+
+ const int nb = ne00/QK_K;
+
+ const int64_t r0 = tgpig.x;
+ const int64_t r1 = tgpig.y;
+ const int64_t im = tgpig.z;
+
+ const int row = 2 * r0 + sgitg;
+
+ const uint i12 = im%ne12;
+ const uint i13 = im/ne12;
+
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+
+ device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb + offset0;
+ device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
+
+ const int ix = tiisg/4;
+ const int il = 4 * (tiisg%4);// 0, 4, 8, 12
+ const int iq = il/8; // 0, 0, 1, 1
+ const int in = il%8; // 0, 4, 0, 4
+
+ float2 sum = {0.f, 0.f};
+
+ for (int i = ix; i < nb; i += 8) {
+
+ const float d_all = (float)(x[i].d);
+
+ device const uint16_t * q = (device const uint16_t *)(x[i].qs + il);
+ device const uint16_t * h = (device const uint16_t *)(x[i].hmask + in);
+ device const uint16_t * s = (device const uint16_t *)(x[i].scales);
+ device const float * y = yy + i * QK_K + il;
+
+ const float d1 = d_all * ((int32_t)(s[0] & 0x000F) - 8);
+ const float d2 = d_all * ((int32_t)(s[0] & 0x00F0) - 128) * 1.f/64.f;
+ const float d3 = d_all * ((int32_t)(s[0] & 0x0F00) - 2048) * 1.f/4096.f;
+ const float d4 = d_all * ((int32_t)(s[0] & 0xF000) - 32768) * 1.f/262144.f;
+
+ for (int l = 0; l < 4; l += 2) {
+ const uint16_t hm = h[l/2] >> iq;
+ sum[0] += y[l+ 0] * d1 * ((int32_t)(q[l/2] & 0x0003) - ((hm & 0x0001) ? 0 : 4))
+ + y[l+16] * d2 * ((int32_t)(q[l/2] & 0x000c) - ((hm & 0x0004) ? 0 : 16))
+ + y[l+32] * d3 * ((int32_t)(q[l/2] & 0x0030) - ((hm & 0x0010) ? 0 : 64))
+ + y[l+48] * d4 * ((int32_t)(q[l/2] & 0x00c0) - ((hm & 0x0040) ? 0 : 256));
+ sum[1] += y[l+ 1] * d1 * ((int32_t)(q[l/2] & 0x0300) - ((hm & 0x0100) ? 0 : 1024))
+ + y[l+17] * d2 * ((int32_t)(q[l/2] & 0x0c00) - ((hm & 0x0400) ? 0 : 4096))
+ + y[l+33] * d3 * ((int32_t)(q[l/2] & 0x3000) - ((hm & 0x1000) ? 0 : 16384))
+ + y[l+49] * d4 * ((int32_t)(q[l/2] & 0xc000) - ((hm & 0x4000) ? 0 : 65536));
+ }
+
+ }
+ const float sumf = sum[0] + sum[1] * 1.f/256.f;
+
+ const float tot = simd_sum(sumf);
+ if (tiisg == 0) {
+ dst[r1*ne0 + im*ne0*ne1 + row] = tot;
+ }
+
+}
+#endif
+
+[[host_name("kernel_mul_mv_q3_K_f32")]]
+kernel void kernel_mul_mv_q3_K_f32(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+
+ kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
+}
+
+#if QK_K == 256
+void kernel_mul_mv_q4_K_f32_impl(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne10,
+ constant int64_t & ne12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+
+ const uint16_t kmask1 = 0x3f3f;
+ const uint16_t kmask2 = 0x0f0f;
+ const uint16_t kmask3 = 0xc0c0;
+
+ const int ix = tiisg/8; // 0...3
+ const int it = tiisg%8; // 0...7
+ const int iq = it/4; // 0 or 1
+ const int ir = it%4; // 0...3
+
+ const int nb = ne00/QK_K;
+ const int r0 = tgpig.x;
+ const int r1 = tgpig.y;
+ const int im = tgpig.z;
+ //const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+ const int first_row = r0 * N_DST;
+ const int ib_row = first_row * nb;
+
+ const uint i12 = im%ne12;
+ const uint i13 = im/ne12;
+
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+
+ device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
+ device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
+
+ float yl[16];
+ float yh[16];
+ float sumf[N_DST]={0.f}, all_sum;
+
+ const int step = sizeof(block_q4_K) * nb / 2;
+
+ device const float * y4 = y + ix * QK_K + 64 * iq + 8 * ir;
+
+ uint16_t sc16[4];
+ thread const uint8_t * sc8 = (thread const uint8_t *)sc16;
+
+ for (int ib = ix; ib < nb; ib += 4) {
+
+ float4 sumy = {0.f, 0.f, 0.f, 0.f};
+ for (int i = 0; i < 8; ++i) {
+ yl[i+0] = y4[i+ 0]; sumy[0] += yl[i+0];
+ yl[i+8] = y4[i+ 32]; sumy[1] += yl[i+8];
+ yh[i+0] = y4[i+128]; sumy[2] += yh[i+0];
+ yh[i+8] = y4[i+160]; sumy[3] += yh[i+8];
+ }
+
+ device const uint16_t * sc = (device const uint16_t *)x[ib].scales + iq;
+ device const uint16_t * q1 = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir;
+ device const half * dh = &x[ib].d;
+
+ for (int row = 0; row < N_DST; row++) {
+
+ sc16[0] = sc[0] & kmask1;
+ sc16[1] = sc[2] & kmask1;
+ sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2);
+ sc16[3] = ((sc[4] >> 4) & kmask2) | ((sc[2] & kmask3) >> 2);
+
+ device const uint16_t * q2 = q1 + 32;
+
+ float4 acc1 = {0.f, 0.f, 0.f, 0.f};
+ float4 acc2 = {0.f, 0.f, 0.f, 0.f};
+ for (int i = 0; i < 8; i += 2) {
+ acc1[0] += yl[i+0] * (q1[i/2] & 0x000F);
+ acc1[1] += yl[i+1] * (q1[i/2] & 0x0F00);
+ acc1[2] += yl[i+8] * (q1[i/2] & 0x00F0);
+ acc1[3] += yl[i+9] * (q1[i/2] & 0xF000);
+ acc2[0] += yh[i+0] * (q2[i/2] & 0x000F);
+ acc2[1] += yh[i+1] * (q2[i/2] & 0x0F00);
+ acc2[2] += yh[i+8] * (q2[i/2] & 0x00F0);
+ acc2[3] += yh[i+9] * (q2[i/2] & 0xF000);
+ }
+
+ float dall = dh[0];
+ float dmin = dh[1];
+ sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc8[0] +
+ (acc1[2] + 1.f/256.f * acc1[3]) * sc8[1] * 1.f/16.f +
+ (acc2[0] + 1.f/256.f * acc2[1]) * sc8[4] +
+ (acc2[2] + 1.f/256.f * acc2[3]) * sc8[5] * 1.f/16.f) -
+ dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
+
+ q1 += step;
+ sc += step;
+ dh += step;
+ }
+
+ y4 += 4 * QK_K;
+ }
+
+ for (int row = 0; row < N_DST; ++row) {
+ all_sum = simd_sum(sumf[row]);
+ if (tiisg == 0) {
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
+ }
+ }
+}
+#else
+void kernel_mul_mv_q4_K_f32_impl(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne10,
+ constant int64_t & ne12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+
+ const int ix = tiisg/4; // 0...7
+ const int it = tiisg%4; // 0...3
+
+ const int nb = ne00/QK_K;
+ const int r0 = tgpig.x;
+ const int r1 = tgpig.y;
+ const int im = tgpig.z;
+ const int first_row = r0 * N_DST;
+ const int ib_row = first_row * nb;
+
+ const uint i12 = im%ne12;
+ const uint i13 = im/ne12;
+
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+
+ device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
+ device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
+
+ float yl[8];
+ float yh[8];
+ float sumf[N_DST]={0.f}, all_sum;
+
+ const int step = sizeof(block_q4_K) * nb / 2;
+
+ device const float * y4 = y + ix * QK_K + 8 * it;
+
+ uint16_t sc16[4];
+
+ for (int ib = ix; ib < nb; ib += 8) {
+
+ float2 sumy = {0.f, 0.f};
+ for (int i = 0; i < 8; ++i) {
+ yl[i] = y4[i+ 0]; sumy[0] += yl[i];
+ yh[i] = y4[i+32]; sumy[1] += yh[i];
+ }
+
+ device const uint16_t * sc = (device const uint16_t *)x[ib].scales;
+ device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 4 * it;
+ device const half * dh = x[ib].d;
+
+ for (int row = 0; row < N_DST; row++) {
+
+ sc16[0] = sc[0] & 0x000f;
+ sc16[1] = sc[0] & 0x0f00;
+ sc16[2] = sc[0] & 0x00f0;
+ sc16[3] = sc[0] & 0xf000;
+
+ float2 acc1 = {0.f, 0.f};
+ float2 acc2 = {0.f, 0.f};
+ for (int i = 0; i < 8; i += 2) {
+ acc1[0] += yl[i+0] * (qs[i/2] & 0x000F);
+ acc1[1] += yl[i+1] * (qs[i/2] & 0x0F00);
+ acc2[0] += yh[i+0] * (qs[i/2] & 0x00F0);
+ acc2[1] += yh[i+1] * (qs[i/2] & 0xF000);
+ }
+
+ float dall = dh[0];
+ float dmin = dh[1];
+ sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc16[0] +
+ (acc2[0] + 1.f/256.f * acc2[1]) * sc16[1] * 1.f/4096.f) -
+ dmin * 1.f/16.f * (sumy[0] * sc16[2] + sumy[1] * sc16[3] * 1.f/256.f);
+
+ qs += step;
+ sc += step;
+ dh += step;
+ }
+
+ y4 += 8 * QK_K;
+ }
+
+ for (int row = 0; row < N_DST; ++row) {
+ all_sum = simd_sum(sumf[row]);
+ if (tiisg == 0) {
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
+ }
+ }
+}
+#endif
+
+[[host_name("kernel_mul_mv_q4_K_f32")]]
+kernel void kernel_mul_mv_q4_K_f32(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+
+ kernel_mul_mv_q4_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
+}
+
+void kernel_mul_mv_q5_K_f32_impl(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne10,
+ constant int64_t & ne12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+
+ const int nb = ne00/QK_K;
+
+ const int64_t r0 = tgpig.x;
+ const int64_t r1 = tgpig.y;
+ const int im = tgpig.z;
+
+ const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
+
+ const uint i12 = im%ne12;
+ const uint i13 = im/ne12;
+
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+
+ device const block_q5_K * x = (device const block_q5_K *) src0 + first_row*nb + offset0;
+ device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
+
+ float sumf[2]={0.f};
+
+ const int step = sizeof(block_q5_K) * nb;
+
+#if QK_K == 256
+#
+ float yl[16], yh[16];
+
+ const uint16_t kmask1 = 0x3f3f;
+ const uint16_t kmask2 = 0x0f0f;
+ const uint16_t kmask3 = 0xc0c0;
+
+ const int tid = tiisg/4;
+ const int ix = tiisg%4;
+ const int iq = tid/4;
+ const int ir = tid%4;
+ const int n = 8;
+
+ const int l0 = n*ir;
+ const int q_offset = 32*iq + l0;
+ const int y_offset = 64*iq + l0;
+
+ const uint8_t hm1 = 1u << (2*iq);
+ const uint8_t hm2 = hm1 << 1;
+ const uint8_t hm3 = hm1 << 4;
+ const uint8_t hm4 = hm2 << 4;
+
+ uint16_t sc16[4];
+ thread const uint8_t * sc8 = (thread const uint8_t *)sc16;
+
+ device const float * y1 = yy + ix*QK_K + y_offset;
+
+ for (int i = ix; i < nb; i += 4) {
+
+ device const uint8_t * q1 = x[i].qs + q_offset;
+ device const uint8_t * qh = x[i].qh + l0;
+ device const half * dh = &x[i].d;
+ device const uint16_t * a = (device const uint16_t *)x[i].scales + iq;
+
+ device const float * y2 = y1 + 128;
+ float4 sumy = {0.f, 0.f, 0.f, 0.f};
+ for (int l = 0; l < 8; ++l) {
+ yl[l+0] = y1[l+ 0]; sumy[0] += yl[l+0];
+ yl[l+8] = y1[l+32]; sumy[1] += yl[l+8];
+ yh[l+0] = y2[l+ 0]; sumy[2] += yh[l+0];
+ yh[l+8] = y2[l+32]; sumy[3] += yh[l+8];
+ }
+
+ for (int row = 0; row < 2; ++row) {
+
+ device const uint8_t * q2 = q1 + 64;
+
+ sc16[0] = a[0] & kmask1;
+ sc16[1] = a[2] & kmask1;
+ sc16[2] = ((a[4] >> 0) & kmask2) | ((a[0] & kmask3) >> 2);
+ sc16[3] = ((a[4] >> 4) & kmask2) | ((a[2] & kmask3) >> 2);
+
+ float4 acc1 = {0.f};
+ float4 acc2 = {0.f};
+ for (int l = 0; l < n; ++l) {
+ uint8_t h = qh[l];
+ acc1[0] += yl[l+0] * (q1[l] & 0x0F);
+ acc1[1] += yl[l+8] * (q1[l] & 0xF0);
+ acc1[2] += yh[l+0] * (q2[l] & 0x0F);
+ acc1[3] += yh[l+8] * (q2[l] & 0xF0);
+ acc2[0] += h & hm1 ? yl[l+0] : 0.f;
+ acc2[1] += h & hm2 ? yl[l+8] : 0.f;
+ acc2[2] += h & hm3 ? yh[l+0] : 0.f;
+ acc2[3] += h & hm4 ? yh[l+8] : 0.f;
+ }
+ const float dall = dh[0];
+ const float dmin = dh[1];
+ sumf[row] += dall * (sc8[0] * (acc1[0] + 16.f*acc2[0]) +
+ sc8[1] * (acc1[1]/16.f + 16.f*acc2[1]) +
+ sc8[4] * (acc1[2] + 16.f*acc2[2]) +
+ sc8[5] * (acc1[3]/16.f + 16.f*acc2[3])) -
+ dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
+
+ q1 += step;
+ qh += step;
+ dh += step/2;
+ a += step/2;
+
+ }
+
+ y1 += 4 * QK_K;
+
+ }
+#else
+ float yl[8], yh[8];
+
+ const int il = 4 * (tiisg/8); // 0, 4, 8, 12
+ const int ix = tiisg%8;
+ const int iq = il/8; // 0, 0, 1, 1
+ const int in = il%8; // 0, 4, 0, 4
+
+ device const float * y = yy + ix*QK_K + il;
+
+ for (int i = ix; i < nb; i += 8) {
+
+ for (int l = 0; l < 4; ++l) {
+ yl[l+0] = y[l+ 0];
+ yl[l+4] = y[l+16];
+ yh[l+0] = y[l+32];
+ yh[l+4] = y[l+48];
+ }
+
+ device const half * dh = &x[i].d;
+ device const uint8_t * q = x[i].qs + il;
+ device const uint8_t * h = x[i].qh + in;
+ device const int8_t * s = x[i].scales;
+
+ for (int row = 0; row < 2; ++row) {
+
+ const float d = dh[0];
+
+ float2 acc = {0.f, 0.f};
+ for (int l = 0; l < 4; ++l) {
+ const uint8_t hl = h[l] >> iq;
+ acc[0] += yl[l+0] * s[0] * ((int16_t)(q[l+ 0] & 0x0F) - (hl & 0x01 ? 0 : 16))
+ + yl[l+4] * s[1] * ((int16_t)(q[l+16] & 0x0F) - (hl & 0x04 ? 0 : 16));
+ acc[1] += yh[l+0] * s[2] * ((int16_t)(q[l+ 0] & 0xF0) - (hl & 0x10 ? 0 : 256))
+ + yh[l+4] * s[3] * ((int16_t)(q[l+16] & 0xF0) - (hl & 0x40 ? 0 : 256));
+ }
+ sumf[row] += d * (acc[0] + 1.f/16.f * acc[1]);
+
+ q += step;
+ h += step;
+ s += step;
+ dh += step/2;
+
+ }
+
+ y += 8 * QK_K;
+ }
+#endif
+
+ for (int row = 0; row < 2; ++row) {
+ const float tot = simd_sum(sumf[row]);
+ if (tiisg == 0) {
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
+ }
+ }
+}
+
+[[host_name("kernel_mul_mv_q5_K_f32")]]
+kernel void kernel_mul_mv_q5_K_f32(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+
+ kernel_mul_mv_q5_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
+}
+
+void kernel_mul_mv_q6_K_f32_impl(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne10,
+ constant int64_t & ne12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+
+ const uint8_t kmask1 = 0x03;
+ const uint8_t kmask2 = 0x0C;
+ const uint8_t kmask3 = 0x30;
+ const uint8_t kmask4 = 0xC0;
+
+ const int nb = ne00/QK_K;
+
+ const int64_t r0 = tgpig.x;
+ const int64_t r1 = tgpig.y;
+ const int im = tgpig.z;
+
+ const int row = 2 * r0 + sgitg;
+
+ const uint i12 = im%ne12;
+ const uint i13 = im/ne12;
+
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+
+ device const block_q6_K * x = (device const block_q6_K *) src0 + row * nb + offset0;
+ device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
+
+ float sumf = 0;
+
+#if QK_K == 256
+ const int tid = tiisg/2;
+ const int ix = tiisg%2;
+ const int ip = tid/8; // 0 or 1
+ const int il = tid%8;
+ const int n = 4;
+ const int l0 = n*il;
+ const int is = 8*ip + l0/16;
+
+ const int y_offset = 128*ip + l0;
+ const int q_offset_l = 64*ip + l0;
+ const int q_offset_h = 32*ip + l0;
+
+ for (int i = ix; i < nb; i += 2) {
+
+ device const uint8_t * q1 = x[i].ql + q_offset_l;
+ device const uint8_t * q2 = q1 + 32;
+ device const uint8_t * qh = x[i].qh + q_offset_h;
+ device const int8_t * sc = x[i].scales + is;
+
+ device const float * y = yy + i * QK_K + y_offset;
+
+ const float dall = x[i].d;
+
+ float4 sums = {0.f, 0.f, 0.f, 0.f};
+ for (int l = 0; l < n; ++l) {
+ sums[0] += y[l+ 0] * ((int8_t)((q1[l] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
+ sums[1] += y[l+32] * ((int8_t)((q2[l] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
+ sums[2] += y[l+64] * ((int8_t)((q1[l] >> 4) | ((qh[l] & kmask3) << 0)) - 32);
+ sums[3] += y[l+96] * ((int8_t)((q2[l] >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
+ }
+
+ sumf += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]);
+
+ }
+
+#else
+ const int ix = tiisg/4;
+ const int il = 4*(tiisg%4);
+
+ for (int i = ix; i < nb; i += 8) {
+ device const float * y = yy + i * QK_K + il;
+ device const uint8_t * ql = x[i].ql + il;
+ device const uint8_t * qh = x[i].qh + il;
+ device const int8_t * s = x[i].scales;
+
+ const float d = x[i].d;
+
+ float4 sums = {0.f, 0.f, 0.f, 0.f};
+ for (int l = 0; l < 4; ++l) {
+ sums[0] += y[l+ 0] * ((int8_t)((ql[l+ 0] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
+ sums[1] += y[l+16] * ((int8_t)((ql[l+16] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
+ sums[2] += y[l+32] * ((int8_t)((ql[l+ 0] >> 4) | ((qh[l] & kmask3) >> 0)) - 32);
+ sums[3] += y[l+48] * ((int8_t)((ql[l+16] >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
+ }
+ sumf += d * (sums[0] * s[0] + sums[1] * s[1] + sums[2] * s[2] + sums[3] * s[3]);
+ }
+
+#endif
+
+ const float tot = simd_sum(sumf);
+ if (tiisg == 0) {
+ dst[r1*ne0 + im*ne0*ne1 + row] = tot;
+ }
+}
+
+[[host_name("kernel_mul_mv_q6_K_f32")]]
+kernel void kernel_mul_mv_q6_K_f32(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+
+ kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
+}
+
+//============================= templates and their specializations =============================
+
+// NOTE: this is not dequantizing - we are simply fitting the template
+template <typename type4x4>
+void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
+ float4x4 temp = *(((device float4x4 *)src));
+ for (int i = 0; i < 16; i++){
+ reg[i/4][i%4] = temp[i/4][i%4];
+ }
+}
+
+template <typename type4x4>
+void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) {
+ half4x4 temp = *(((device half4x4 *)src));
+ for (int i = 0; i < 16; i++){
+ reg[i/4][i%4] = temp[i/4][i%4];
+ }
+}
+
+template <typename type4x4>
+void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
+ device const uint16_t * qs = ((device const uint16_t *)xb + 1);
+ const float d1 = il ? (xb->d / 16.h) : xb->d;
+ const float d2 = d1 / 256.f;
+ const float md = -8.h * xb->d;
+ const ushort mask0 = il ? 0x00F0 : 0x000F;
+ const ushort mask1 = mask0 << 8;
+
+ for (int i=0;i<8;i++) {
+ reg[i/2][2*(i%2)+0] = d1 * (qs[i] & mask0) + md;
+ reg[i/2][2*(i%2)+1] = d2 * (qs[i] & mask1) + md;
+ }
+}
+
+template <typename type4x4>
+void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) {
+ device const uint16_t * qs = ((device const uint16_t *)xb + 2);
+ const float d1 = il ? (xb->d / 16.h) : xb->d;
+ const float d2 = d1 / 256.f;
+ const float m = xb->m;
+ const ushort mask0 = il ? 0x00F0 : 0x000F;
+ const ushort mask1 = mask0 << 8;
+
+ for (int i=0;i<8;i++) {
+ reg[i/2][2*(i%2)+0] = ((qs[i] & mask0) * d1) + m;
+ reg[i/2][2*(i%2)+1] = ((qs[i] & mask1) * d2) + m;
+ }
+}
+
+template <typename type4x4>
+void dequantize_q5_0(device const block_q5_0 *xb, short il, thread type4x4 & reg) {
+ device const uint16_t * qs = ((device const uint16_t *)xb + 3);
+ const float d = xb->d;
+ const float md = -16.h * xb->d;
+ const ushort mask = il ? 0x00F0 : 0x000F;
+
+ const uint32_t qh = *((device const uint32_t *)xb->qh);
+
+ const int x_mv = il ? 4 : 0;
+
+ const int gh_mv = il ? 12 : 0;
+ const int gh_bk = il ? 0 : 4;
+
+ for (int i = 0; i < 8; i++) {
+ // extract the 5-th bits for x0 and x1
+ const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
+ const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
+
+ // combine the 4-bits from qs with the 5th bit
+ const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
+ const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
+
+ reg[i/2][2*(i%2)+0] = d * x0 + md;
+ reg[i/2][2*(i%2)+1] = d * x1 + md;
+ }
+}
+
+template <typename type4x4>
+void dequantize_q5_1(device const block_q5_1 *xb, short il, thread type4x4 & reg) {
+ device const uint16_t * qs = ((device const uint16_t *)xb + 4);
+ const float d = xb->d;
+ const float m = xb->m;
+ const ushort mask = il ? 0x00F0 : 0x000F;
+
+ const uint32_t qh = *((device const uint32_t *)xb->qh);
+
+ const int x_mv = il ? 4 : 0;
+
+ const int gh_mv = il ? 12 : 0;
+ const int gh_bk = il ? 0 : 4;
+
+ for (int i = 0; i < 8; i++) {
+ // extract the 5-th bits for x0 and x1
+ const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
+ const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
+
+ // combine the 4-bits from qs with the 5th bit
+ const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
+ const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
+
+ reg[i/2][2*(i%2)+0] = d * x0 + m;
+ reg[i/2][2*(i%2)+1] = d * x1 + m;
+ }
+}
+
+template <typename type4x4>
+void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) {
+ device const int8_t * qs = ((device const int8_t *)xb->qs);
+ const half d = xb->d;
+
+ for (int i = 0; i < 16; i++) {
+ reg[i/4][i%4] = (qs[i + 16*il] * d);
+ }
+}
+
+template <typename type4x4>
+void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
+ const float d = xb->d;
+ const float min = xb->dmin;
+ device const uint8_t * q = (device const uint8_t *)xb->qs;
+ float dl, ml;
+ uint8_t sc = xb->scales[il];
+
+#if QK_K == 256
+ q = q + 32*(il/8) + 16*(il&1);
+ il = (il/2)%4;
+#endif
+ half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
+ uchar mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
+ dl = d * (sc & 0xF) * coef, ml = min * (sc >> 4);
+ for (int i = 0; i < 16; ++i) {
+ reg[i/4][i%4] = dl * (q[i] & mask) - ml;
+ }
+}
+
+template <typename type4x4>
+void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg) {
+ const half d_all = xb->d;
+ device const uint8_t * q = (device const uint8_t *)xb->qs;
+ device const uint8_t * h = (device const uint8_t *)xb->hmask;
+ device const int8_t * scales = (device const int8_t *)xb->scales;
+
+#if QK_K == 256
+ q = q + 32 * (il/8) + 16 * (il&1);
+ h = h + 16 * (il&1);
+ uint8_t m = 1 << (il/2);
+ uint16_t kmask1 = (il/4)>1 ? ((il/4)>2 ? 192 : 48) : \
+ ((il/4)>0 ? 12 : 3);
+ uint16_t kmask2 = il/8 ? 0xF0 : 0x0F;
+ uint16_t scale_2 = scales[il%8], scale_1 = scales[8 + il%4];
+ int16_t dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2)
+ : (scale_2&kmask2) | ((scale_1&kmask1) << 4);
+ half dl = il<8 ? d_all * (dl_int - 32.h) : d_all * (dl_int / 16.h - 32.h);
+ const half ml = 4.h * dl;
+
+ il = (il/2) & 3;
+ const half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
+ const uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
+ dl *= coef;
+
+ for (int i = 0; i < 16; ++i) {
+ reg[i/4][i%4] = dl * (q[i] & mask) - (h[i] & m ? 0 : ml);
+ }
+#else
+ float kcoef = il&1 ? 1.f/16.f : 1.f;
+ uint16_t kmask = il&1 ? 0xF0 : 0x0F;
+ float dl = d_all * ((scales[il/2] & kmask) * kcoef - 8);
+ float coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
+ uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
+ uint8_t m = 1<<(il*2);
+ for (int i = 0; i < 16; ++i) {
+ reg[i/4][i%4] = coef * dl * ((q[i] & mask) - ((h[i%8] & (m * (1 + i/8))) ? 0 : 4.f/coef));
+ }
+#endif
+}
+
+static inline uchar2 get_scale_min_k4_just2(int j, int k, device const uchar * q) {
+ return j < 4 ? uchar2{uchar(q[j+0+k] & 63), uchar(q[j+4+k] & 63)}
+ : uchar2{uchar((q[j+4+k] & 0xF) | ((q[j-4+k] & 0xc0) >> 2)), uchar((q[j+4+k] >> 4) | ((q[j-0+k] & 0xc0) >> 2))};
+}
+
+template <typename type4x4>
+void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg) {
+ device const uchar * q = xb->qs;
+
+#if QK_K == 256
+ short is = (il/4) * 2;
+ q = q + (il/4) * 32 + 16 * (il&1);
+ il = il & 3;
+ const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
+ const float d = il < 2 ? xb->d : xb->d / 16.h;
+ const float min = xb->dmin;
+ const float dl = d * sc[0];
+ const float ml = min * sc[1];
+#else
+ q = q + 16 * (il&1);
+ device const uint8_t * s = xb->scales;
+ device const half2 * dh = (device const half2 *)xb->d;
+ const float2 d = (float2)dh[0];
+ const float dl = il<2 ? d[0] * (s[0]&0xF) : d[0] * (s[1]&0xF)/16.h;
+ const float ml = il<2 ? d[1] * (s[0]>>4) : d[1] * (s[1]>>4);
+#endif
+ const ushort mask = il<2 ? 0x0F : 0xF0;
+ for (int i = 0; i < 16; ++i) {
+ reg[i/4][i%4] = dl * (q[i] & mask) - ml;
+ }
+}
+
+template <typename type4x4>
+void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg) {
+ device const uint8_t * q = xb->qs;
+ device const uint8_t * qh = xb->qh;
+
+#if QK_K == 256
+ short is = (il/4) * 2;
+ q = q + 32 * (il/4) + 16 * (il&1);
+ qh = qh + 16 * (il&1);
+ uint8_t ul = 1 << (il/2);
+ il = il & 3;
+ const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
+ const float d = il < 2 ? xb->d : xb->d / 16.h;
+ const float min = xb->dmin;
+ const float dl = d * sc[0];
+ const float ml = min * sc[1];
+
+ const ushort mask = il<2 ? 0x0F : 0xF0;
+ const float qh_val = il<2 ? 16.f : 256.f;
+ for (int i = 0; i < 16; ++i) {
+ reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml;
+ }
+#else
+ q = q + 16 * (il&1);
+ device const int8_t * s = xb->scales;
+ const float dl = xb->d * s[il];
+ uint8_t m = 1<<(il*2);
+ const float coef = il<2 ? 1.f : 1.f/16.f;
+ const ushort mask = il<2 ? 0x0F : 0xF0;
+ for (int i = 0; i < 16; ++i) {
+ reg[i/4][i%4] = coef * dl * ((q[i] & mask) - (qh[i%8] & (m*(1+i/8)) ? 0.f : 16.f/coef));
+ }
+#endif
+}
+
+template <typename type4x4>
+void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg) {
+ const half d_all = xb->d;
+ device const uint8_t * ql = (device const uint8_t *)xb->ql;
+ device const uint8_t * qh = (device const uint8_t *)xb->qh;
+ device const int8_t * scales = (device const int8_t *)xb->scales;
+
+#if QK_K == 256
+ ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1);
+ qh = qh + 32*(il/8) + 16*(il&1);
+ half sc = scales[(il%2) + 2 * ((il/2))];
+ il = (il/2) & 3;
+#else
+ ql = ql + 16 * (il&1);
+ half sc = scales[il];
+#endif
+ const uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
+ const uint16_t kmask2 = il>1 ? 0xF0 : 0x0F;
+ const half coef = il>1 ? 1.f/16.h : 1.h;
+ const half ml = d_all * sc * 32.h;
+ const half dl = d_all * sc * coef;
+ for (int i = 0; i < 16; ++i) {
+ const half q = il&1 ? ((ql[i] & kmask2) | ((qh[i] & kmask1) << 2))
+ : ((ql[i] & kmask2) | ((qh[i] & kmask1) << 4));
+ reg[i/4][i%4] = dl * q - ml;
+ }
+}
+
+template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
+kernel void kernel_get_rows(
+ device const void * src0,
+ device const char * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiitg[[thread_index_in_threadgroup]],
+ uint3 tptg [[threads_per_threadgroup]]) {
+ //const int64_t i = tgpig;
+ //const int64_t r = ((device int32_t *) src1)[i];
+
+ const int64_t i10 = tgpig.x;
+ const int64_t i11 = tgpig.y;
+
+ const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
+
+ const int64_t i02 = i11;
+
+ for (int64_t ind = tiitg; ind < ne00/16; ind += tptg.x) {
+ float4x4 temp;
+ dequantize_func(
+ ((device const block_q *) ((device char *) src0 + r*nb01 + i02*nb02)) + ind/nl, ind%nl, temp);
+ *(((device float4x4 *) ((device char *) dst + i11*nb2 + i10*nb1)) + ind) = temp;
+ }
+}
+
+kernel void kernel_get_rows_f32(
+ device const void * src0,
+ device const char * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiitg[[thread_index_in_threadgroup]],
+ uint3 tptg [[threads_per_threadgroup]]) {
+ const int64_t i10 = tgpig.x;
+ const int64_t i11 = tgpig.y;
+
+ const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
+
+ const int64_t i02 = i11;
+
+ for (int ind = tiitg; ind < ne00; ind += tptg.x) {
+ ((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
+ ((device float *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
+ }
+}
+
+kernel void kernel_get_rows_f16(
+ device const void * src0,
+ device const char * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiitg[[thread_index_in_threadgroup]],
+ uint3 tptg [[threads_per_threadgroup]]) {
+ const int64_t i10 = tgpig.x;
+ const int64_t i11 = tgpig.y;
+
+ const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
+
+ const int64_t i02 = i11;
+
+ for (int ind = tiitg; ind < ne00; ind += tptg.x) {
+ ((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
+ ((device half *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
+ }
+}
+
+#define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A
+#define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B
+#define BLOCK_SIZE_K 32
+#define THREAD_MAT_M 4 // each thread take 4 simdgroup matrices from matrix A
+#define THREAD_MAT_N 2 // each thread take 2 simdgroup matrices from matrix B
+#define THREAD_PER_BLOCK 128
+#define THREAD_PER_ROW 2 // 2 thread for each row in matrix A to load numbers
+#define THREAD_PER_COL 4 // 4 thread for each row in matrix B to load numbers
+#define SG_MAT_SIZE 64 // simdgroup matrix is of shape 8x8
+#define SG_MAT_ROW 8
+
+// each block_q contains 16*nl weights
+template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
+void kernel_mul_mm_impl(device const uchar * src0,
+ device const uchar * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne02,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne12,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
+ threadgroup uchar * shared_memory [[threadgroup(0)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiitg[[thread_index_in_threadgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+
+ threadgroup half * sa = (threadgroup half *)(shared_memory);
+ threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
+
+ const uint r0 = tgpig.y;
+ const uint r1 = tgpig.x;
+ const uint im = tgpig.z;
+
+ // if this block is of 64x32 shape or smaller
+ short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M;
+ short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N;
+
+ // a thread shouldn't load data outside of the matrix
+ short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
+ short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
+
+ simdgroup_half8x8 ma[4];
+ simdgroup_float8x8 mb[2];
+ simdgroup_float8x8 c_res[8];
+ for (int i = 0; i < 8; i++){
+ c_res[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
+ }
+
+ short il = (tiitg % THREAD_PER_ROW);
+
+ const uint i12 = im%ne12;
+ const uint i13 = im/ne12;
+
+ uint offset0 = (i12/r2)*nb02 + (i13/r3)*(nb02*ne02);
+ ushort offset1 = il/nl;
+
+ device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
+ device const float * y = (device const float *)(src1
+ + nb12 * im
+ + nb11 * (r1 * BLOCK_SIZE_N + thread_col)
+ + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
+
+ for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
+ // load data and store to threadgroup memory
+ half4x4 temp_a;
+ dequantize_func(x, il, temp_a);
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ #pragma unroll(16)
+ for (int i = 0; i < 16; i++) {
+ *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \
+ + (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \
+ + (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4];
+ }
+
+ *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y);
+
+ il = (il + 2 < nl) ? il + 2 : il % 2;
+ x = (il < 2) ? x + (2+nl-1)/nl : x;
+ y += BLOCK_SIZE_K;
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ // load matrices from threadgroup memory and conduct outer products
+ threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
+ threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));
+
+ #pragma unroll(4)
+ for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
+ #pragma unroll(4)
+ for (int i = 0; i < 4; i++) {
+ simdgroup_load(ma[i],lsma + SG_MAT_SIZE * i);
+ }
+ simdgroup_barrier(mem_flags::mem_none);
+ #pragma unroll(2)
+ for (int i = 0; i < 2; i++) {
+ simdgroup_load(mb[i],lsmb + SG_MAT_SIZE * i);
+ }
+
+ lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
+ lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE;
+
+ #pragma unroll(8)
+ for (int i = 0; i < 8; i++){
+ simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]);
+ }
+ }
+ }
+
+ if ((r0 + 1) * BLOCK_SIZE_M <= ne0 && (r1 + 1) * BLOCK_SIZE_N <= ne1) {
+ device float * C = dst + (BLOCK_SIZE_M * r0 + 32 * (sgitg & 1)) \
+ + (BLOCK_SIZE_N * r1 + 16 * (sgitg >> 1)) * ne0 + im*ne1*ne0;
+ for (int i = 0; i < 8; i++) {
+ simdgroup_store(c_res[i], C + 8 * (i%4) + 8 * ne0 * (i/4), ne0);
+ }
+ } else {
+ // block is smaller than 64x32, we should avoid writing data outside of the matrix
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ threadgroup float * temp_str = ((threadgroup float *)shared_memory) \
+ + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
+ for (int i = 0; i < 8; i++) {
+ simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ device float * C = dst + (BLOCK_SIZE_M * r0) + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0;
+ if (sgitg == 0) {
+ for (int i = 0; i < n_rows; i++) {
+ for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
+ *(C + i + j * ne0) = *(temp_str + i + j * BLOCK_SIZE_M);
+ }
+ }
+ }
+ }
+}
+
+// same as kernel_mul_mm_impl, but src1 and dst are accessed via indices stored in src1ids
+template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
+void kernel_mul_mm_id_impl(
+ device const uchar * src0,
+ device const uchar * src1,
+ thread short * src1ids,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne02,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne12,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ int64_t ne1,
+ constant uint & r2,
+ constant uint & r3,
+ threadgroup uchar * shared_memory,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiitg[[thread_index_in_threadgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+
+ threadgroup half * sa = (threadgroup half *)(shared_memory);
+ threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
+
+ const uint r0 = tgpig.y;
+ const uint r1 = tgpig.x;
+ const uint im = tgpig.z;
+
+ if (r1 * BLOCK_SIZE_N >= ne1) return;
+
+ // if this block is of 64x32 shape or smaller
+ short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M;
+ short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N;
+
+ // a thread shouldn't load data outside of the matrix
+ short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
+ short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
+
+ simdgroup_half8x8 ma[4];
+ simdgroup_float8x8 mb[2];
+ simdgroup_float8x8 c_res[8];
+ for (int i = 0; i < 8; i++){
+ c_res[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
+ }
+
+ short il = (tiitg % THREAD_PER_ROW);
+
+ const uint i12 = im%ne12;
+ const uint i13 = im/ne12;
+
+ uint offset0 = (i12/r2)*nb02 + (i13/r3)*(nb02*ne02);
+ ushort offset1 = il/nl;
+
+ device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
+ device const float * y = (device const float *)(src1
+ + nb12 * im
+ + nb11 * src1ids[r1 * BLOCK_SIZE_N + thread_col]
+ + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
+
+ for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
+ // load data and store to threadgroup memory
+ half4x4 temp_a;
+ dequantize_func(x, il, temp_a);
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ for (int i = 0; i < 16; i++) {
+ *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \
+ + (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \
+ + (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4];
+ }
+
+ *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y);
+
+ il = (il + 2 < nl) ? il + 2 : il % 2;
+ x = (il < 2) ? x + (2+nl-1)/nl : x;
+ y += BLOCK_SIZE_K;
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ // load matrices from threadgroup memory and conduct outer products
+ threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
+ threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));
+
+ for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
+ for (int i = 0; i < 4; i++) {
+ simdgroup_load(ma[i],lsma + SG_MAT_SIZE * i);
+ }
+ simdgroup_barrier(mem_flags::mem_none);
+ for (int i = 0; i < 2; i++) {
+ simdgroup_load(mb[i],lsmb + SG_MAT_SIZE * i);
+ }
+
+ lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
+ lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE;
+
+ for (int i = 0; i < 8; i++){
+ simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]);
+ }
+ }
+ }
+
+ {
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ threadgroup float * temp_str = ((threadgroup float *)shared_memory) \
+ + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
+ for (int i = 0; i < 8; i++) {
+ simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ device float * C = dst + (BLOCK_SIZE_M * r0) + im*ne1*ne0;
+ if (sgitg == 0) {
+ for (int i = 0; i < n_rows; i++) {
+ for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
+ *(C + i + src1ids[j + r1*BLOCK_SIZE_N] * ne0) = *(temp_str + i + j * BLOCK_SIZE_M);
+ }
+ }
+ }
+ }
+}
+
+template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
+kernel void kernel_mul_mm(device const uchar * src0,
+ device const uchar * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne02,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne12,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
+ threadgroup uchar * shared_memory [[threadgroup(0)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiitg[[thread_index_in_threadgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ kernel_mul_mm_impl<block_q, nl, dequantize_func>(
+ src0,
+ src1,
+ dst,
+ ne00,
+ ne02,
+ nb01,
+ nb02,
+ ne12,
+ nb10,
+ nb11,
+ nb12,
+ ne0,
+ ne1,
+ r2,
+ r3,
+ shared_memory,
+ tgpig,
+ tiitg,
+ sgitg);
+}
+
+template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
+kernel void kernel_mul_mm_id(
+ device const uchar * ids,
+ device const uchar * src1,
+ device float * dst,
+ constant uint64_t & nbi1,
+ constant int64_t & ne00,
+ constant int64_t & ne02,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint64_t & nb1,
+ constant uint & r2,
+ constant uint & r3,
+ constant int & idx,
+ device const uchar * src00,
+ device const uchar * src01,
+ device const uchar * src02,
+ device const uchar * src03,
+ device const uchar * src04,
+ device const uchar * src05,
+ device const uchar * src06,
+ device const uchar * src07,
+ threadgroup uchar * shared_memory [[threadgroup(0)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiitg[[thread_index_in_threadgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ device const uchar * src0s[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+ // expert id
+ const int32_t id = tgpig.z/(ne12*ne13);
+
+ tgpig.z = tgpig.z%(ne12*ne13);
+
+ // row indices of src1 for expert id
+ int64_t _ne1 = 0;
+ short src1ids[512];
+
+ for (int64_t i1 = 0; i1 < ne1; i1++) {
+ if (((device int32_t *) (ids + i1*nbi1))[idx] == id) {
+ src1ids[_ne1++] = i1;
+ }
+ }
+
+ kernel_mul_mm_id_impl<block_q, nl, dequantize_func>(
+ src0s[id],
+ src1,
+ src1ids,
+ dst,
+ ne00,
+ ne02,
+ nb01,
+ nb02,
+ ne12,
+ nb10,
+ nb11,
+ nb12,
+ ne0,
+ _ne1,
+ r2,
+ r3,
+ shared_memory,
+ tgpig,
+ tiitg,
+ sgitg);
+}
+
+#if QK_K == 256
+#define QK_NL 16
+#else
+#define QK_NL 4
+#endif
+
+//
+// get rows
+//
+
+typedef void (get_rows_t)(
+ device const void * src0,
+ device const char * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ uint3, uint, uint3);
+
+//template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows<float4x4, 1, dequantize_f32>;
+//template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
+template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
+template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
+template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_t kernel_get_rows<block_q5_0, 2, dequantize_q5_0>;
+template [[host_name("kernel_get_rows_q5_1")]] kernel get_rows_t kernel_get_rows<block_q5_1, 2, dequantize_q5_1>;
+template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_t kernel_get_rows<block_q8_0, 2, dequantize_q8_0>;
+template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_t kernel_get_rows<block_q2_K, QK_NL, dequantize_q2_K>;
+template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows<block_q3_K, QK_NL, dequantize_q3_K>;
+template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows<block_q4_K, QK_NL, dequantize_q4_K>;
+template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows<block_q5_K, QK_NL, dequantize_q5_K>;
+template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows<block_q6_K, QK_NL, dequantize_q6_K>;
+
+//
+// matrix-matrix multiplication
+//
+
+typedef void (mat_mm_t)(
+ device const uchar * src0,
+ device const uchar * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne02,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne12,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
+ threadgroup uchar *,
+ uint3, uint, uint);
+
+template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<float4x4, 1, dequantize_f32>;
+template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
+template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
+template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>;
+template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_0, 2, dequantize_q5_0>;
+template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_1, 2, dequantize_q5_1>;
+template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2, dequantize_q8_0>;
+template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q2_K, QK_NL, dequantize_q2_K>;
+template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q3_K, QK_NL, dequantize_q3_K>;
+template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>;
+template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_K, QK_NL, dequantize_q5_K>;
+template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q6_K, QK_NL, dequantize_q6_K>;
+
+//
+// indirect matrix-matrix multiplication
+//
+
+typedef void (mat_mm_id_t)(
+ device const uchar * ids,
+ device const uchar * src1,
+ device float * dst,
+ constant uint64_t & nbi1,
+ constant int64_t & ne00,
+ constant int64_t & ne02,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint64_t & nb1,
+ constant uint & r2,
+ constant uint & r3,
+ constant int & idx,
+ device const uchar * src00,
+ device const uchar * src01,
+ device const uchar * src02,
+ device const uchar * src03,
+ device const uchar * src04,
+ device const uchar * src05,
+ device const uchar * src06,
+ device const uchar * src07,
+ threadgroup uchar *,
+ uint3, uint, uint);
+
+template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<float4x4, 1, dequantize_f32>;
+template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<half4x4, 1, dequantize_f16>;
+template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_0, 2, dequantize_q4_0>;
+template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_1, 2, dequantize_q4_1>;
+template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_0, 2, dequantize_q5_0>;
+template [[host_name("kernel_mul_mm_id_q5_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_1, 2, dequantize_q5_1>;
+template [[host_name("kernel_mul_mm_id_q8_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q8_0, 2, dequantize_q8_0>;
+template [[host_name("kernel_mul_mm_id_q2_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q2_K, QK_NL, dequantize_q2_K>;
+template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q3_K, QK_NL, dequantize_q3_K>;
+template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_K, QK_NL, dequantize_q4_K>;
+template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_K, QK_NL, dequantize_q5_K>;
+template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q6_K, QK_NL, dequantize_q6_K>;
+
+//
+// matrix-vector multiplication
+//
+
+[[host_name("kernel_mul_mv_id_f32_f32")]]
+kernel void kernel_mul_mv_id_f32_f32(
+ device const char * ids,
+ device const char * src1,
+ device float * dst,
+ constant uint64_t & nbi1,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint64_t & nb1,
+ constant uint & r2,
+ constant uint & r3,
+ constant int & idx,
+ device const char * src00,
+ device const char * src01,
+ device const char * src02,
+ device const char * src03,
+ device const char * src04,
+ device const char * src05,
+ device const char * src06,
+ device const char * src07,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiitg[[thread_index_in_threadgroup]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+ const int64_t bid = tgpig.z/(ne12*ne13);
+
+ tgpig.z = tgpig.z%(ne12*ne13);
+
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+ kernel_mul_mv_f32_f32_impl(
+ src0[id],
+ src1 + bid*nb11,
+ dst + bid*ne0,
+ ne00,
+ ne01,
+ ne02,
+ nb00,
+ nb01,
+ nb02,
+ ne10,
+ ne11,
+ ne12,
+ nb10,
+ nb11,
+ nb12,
+ ne0,
+ ne1,
+ r2,
+ r3,
+ tgpig,
+ tiisg);
+}
+
+[[host_name("kernel_mul_mv_id_f16_f32")]]
+kernel void kernel_mul_mv_id_f16_f32(
+ device const char * ids,
+ device const char * src1,
+ device float * dst,
+ constant uint64_t & nbi1,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint64_t & nb1,
+ constant uint & r2,
+ constant uint & r3,
+ constant int & idx,
+ device const char * src00,
+ device const char * src01,
+ device const char * src02,
+ device const char * src03,
+ device const char * src04,
+ device const char * src05,
+ device const char * src06,
+ device const char * src07,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiitg[[thread_index_in_threadgroup]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+ const int64_t bid = tgpig.z/(ne12*ne13);
+
+ tgpig.z = tgpig.z%(ne12*ne13);
+
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+ kernel_mul_mv_f16_f32_impl(
+ src0[id],
+ src1 + bid*nb11,
+ dst + bid*ne0,
+ ne00,
+ ne01,
+ ne02,
+ nb00,
+ nb01,
+ nb02,
+ ne10,
+ ne11,
+ ne12,
+ nb10,
+ nb11,
+ nb12,
+ ne0,
+ ne1,
+ r2,
+ r3,
+ tgpig,
+ tiisg);
+}
+
+[[host_name("kernel_mul_mv_id_q8_0_f32")]]
+kernel void kernel_mul_mv_id_q8_0_f32(
+ device const char * ids,
+ device const char * src1,
+ device float * dst,
+ constant uint64_t & nbi1,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint64_t & nb1,
+ constant uint & r2,
+ constant uint & r3,
+ constant int & idx,
+ device const char * src00,
+ device const char * src01,
+ device const char * src02,
+ device const char * src03,
+ device const char * src04,
+ device const char * src05,
+ device const char * src06,
+ device const char * src07,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiitg[[thread_index_in_threadgroup]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+ const int64_t bid = tgpig.z/(ne12*ne13);
+
+ tgpig.z = tgpig.z%(ne12*ne13);
+
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+ kernel_mul_mv_q8_0_f32_impl(
+ src0[id],
+ (device const float *) (src1 + bid*nb11),
+ dst + bid*ne0,
+ ne00,
+ ne01,
+ ne02,
+ ne10,
+ ne12,
+ ne0,
+ ne1,
+ r2,
+ r3,
+ tgpig,
+ tiisg,
+ sgitg);
+}
+
+[[host_name("kernel_mul_mv_id_q4_0_f32")]]
+kernel void kernel_mul_mv_id_q4_0_f32(
+ device const char * ids,
+ device const char * src1,
+ device float * dst,
+ constant uint64_t & nbi1,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint64_t & nb1,
+ constant uint & r2,
+ constant uint & r3,
+ constant int & idx,
+ device const char * src00,
+ device const char * src01,
+ device const char * src02,
+ device const char * src03,
+ device const char * src04,
+ device const char * src05,
+ device const char * src06,
+ device const char * src07,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiitg[[thread_index_in_threadgroup]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+ const int64_t bid = tgpig.z/(ne12*ne13);
+
+ tgpig.z = tgpig.z%(ne12*ne13);
+
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+ mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
+ src0[id],
+ (device const float *) (src1 + bid*nb11),
+ dst + bid*ne0,
+ ne00,
+ ne01,
+ ne02,
+ ne10,
+ ne12,
+ ne0,
+ ne1,
+ r2,
+ r3,
+ tgpig,
+ tiisg,
+ sgitg);
+}
+
+[[host_name("kernel_mul_mv_id_q4_1_f32")]]
+kernel void kernel_mul_mv_id_q4_1_f32(
+ device const char * ids,
+ device const char * src1,
+ device float * dst,
+ constant uint64_t & nbi1,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint64_t & nb1,
+ constant uint & r2,
+ constant uint & r3,
+ constant int & idx,
+ device const char * src00,
+ device const char * src01,
+ device const char * src02,
+ device const char * src03,
+ device const char * src04,
+ device const char * src05,
+ device const char * src06,
+ device const char * src07,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiitg[[thread_index_in_threadgroup]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+ const int64_t bid = tgpig.z/(ne12*ne13);
+
+ tgpig.z = tgpig.z%(ne12*ne13);
+
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+ mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
+ src0[id],
+ (device const float *) (src1 + bid*nb11),
+ dst + bid*ne0,
+ ne00,
+ ne01,
+ ne02,
+ ne10,
+ ne12,
+ ne0,
+ ne1,
+ r2,
+ r3,
+ tgpig,
+ tiisg,
+ sgitg);
+}
+
+[[host_name("kernel_mul_mv_id_q5_0_f32")]]
+kernel void kernel_mul_mv_id_q5_0_f32(
+ device const char * ids,
+ device const char * src1,
+ device float * dst,
+ constant uint64_t & nbi1,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint64_t & nb1,
+ constant uint & r2,
+ constant uint & r3,
+ constant int & idx,
+ device const char * src00,
+ device const char * src01,
+ device const char * src02,
+ device const char * src03,
+ device const char * src04,
+ device const char * src05,
+ device const char * src06,
+ device const char * src07,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiitg[[thread_index_in_threadgroup]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+ const int64_t bid = tgpig.z/(ne12*ne13);
+
+ tgpig.z = tgpig.z%(ne12*ne13);
+
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+ mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
+ src0[id],
+ (device const float *) (src1 + bid*nb11),
+ dst + bid*ne0,
+ ne00,
+ ne01,
+ ne02,
+ ne10,
+ ne12,
+ ne0,
+ ne1,
+ r2,
+ r3,
+ tgpig,
+ tiisg,
+ sgitg);
+}
+
+[[host_name("kernel_mul_mv_id_q5_1_f32")]]
+kernel void kernel_mul_mv_id_q5_1_f32(
+ device const char * ids,
+ device const char * src1,
+ device float * dst,
+ constant uint64_t & nbi1,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint64_t & nb1,
+ constant uint & r2,
+ constant uint & r3,
+ constant int & idx,
+ device const char * src00,
+ device const char * src01,
+ device const char * src02,
+ device const char * src03,
+ device const char * src04,
+ device const char * src05,
+ device const char * src06,
+ device const char * src07,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiitg[[thread_index_in_threadgroup]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+ const int64_t bid = tgpig.z/(ne12*ne13);
+
+ tgpig.z = tgpig.z%(ne12*ne13);
+
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+ mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
+ src0[id],
+ (device const float *) (src1 + bid*nb11),
+ dst + bid*ne0,
+ ne00,
+ ne01,
+ ne02,
+ ne10,
+ ne12,
+ ne0,
+ ne1,
+ r2,
+ r3,
+ tgpig,
+ tiisg,
+ sgitg);
+}
+
+[[host_name("kernel_mul_mv_id_q2_K_f32")]]
+kernel void kernel_mul_mv_id_q2_K_f32(
+ device const char * ids,
+ device const char * src1,
+ device float * dst,
+ constant uint64_t & nbi1,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint64_t & nb1,
+ constant uint & r2,
+ constant uint & r3,
+ constant int & idx,
+ device const char * src00,
+ device const char * src01,
+ device const char * src02,
+ device const char * src03,
+ device const char * src04,
+ device const char * src05,
+ device const char * src06,
+ device const char * src07,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiitg[[thread_index_in_threadgroup]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+ const int64_t bid = tgpig.z/(ne12*ne13);
+
+ tgpig.z = tgpig.z%(ne12*ne13);
+
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+ kernel_mul_mv_q2_K_f32_impl(
+ src0[id],
+ (device const float *) (src1 + bid*nb11),
+ dst + bid*ne0,
+ ne00,
+ ne01,
+ ne02,
+ ne10,
+ ne12,
+ ne0,
+ ne1,
+ r2,
+ r3,
+ tgpig,
+ tiisg,
+ sgitg);
+}
+
+[[host_name("kernel_mul_mv_id_q3_K_f32")]]
+kernel void kernel_mul_mv_id_q3_K_f32(
+ device const char * ids,
+ device const char * src1,
+ device float * dst,
+ constant uint64_t & nbi1,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint64_t & nb1,
+ constant uint & r2,
+ constant uint & r3,
+ constant int & idx,
+ device const char * src00,
+ device const char * src01,
+ device const char * src02,
+ device const char * src03,
+ device const char * src04,
+ device const char * src05,
+ device const char * src06,
+ device const char * src07,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiitg[[thread_index_in_threadgroup]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+ const int64_t bid = tgpig.z/(ne12*ne13);
+
+ tgpig.z = tgpig.z%(ne12*ne13);
+
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+ kernel_mul_mv_q3_K_f32_impl(
+ src0[id],
+ (device const float *) (src1 + bid*nb11),
+ dst + bid*ne0,
+ ne00,
+ ne01,
+ ne02,
+ ne10,
+ ne12,
+ ne0,
+ ne1,
+ r2,
+ r3,
+ tgpig,
+ tiisg,
+ sgitg);
+}
+
+[[host_name("kernel_mul_mv_id_q4_K_f32")]]
+kernel void kernel_mul_mv_id_q4_K_f32(
+ device const char * ids,
+ device const char * src1,
+ device float * dst,
+ constant uint64_t & nbi1,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint64_t & nb1,
+ constant uint & r2,
+ constant uint & r3,
+ constant int & idx,
+ device const char * src00,
+ device const char * src01,
+ device const char * src02,
+ device const char * src03,
+ device const char * src04,
+ device const char * src05,
+ device const char * src06,
+ device const char * src07,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiitg[[thread_index_in_threadgroup]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+ const int64_t bid = tgpig.z/(ne12*ne13);
+
+ tgpig.z = tgpig.z%(ne12*ne13);
+
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+ kernel_mul_mv_q4_K_f32_impl(
+ src0[id],
+ (device const float *) (src1 + bid*nb11),
+ dst + bid*ne0,
+ ne00,
+ ne01,
+ ne02,
+ ne10,
+ ne12,
+ ne0,
+ ne1,
+ r2,
+ r3,
+ tgpig,
+ tiisg,
+ sgitg);
+}
+
+[[host_name("kernel_mul_mv_id_q5_K_f32")]]
+kernel void kernel_mul_mv_id_q5_K_f32(
+ device const char * ids,
+ device const char * src1,
+ device float * dst,
+ constant uint64_t & nbi1,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint64_t & nb1,
+ constant uint & r2,
+ constant uint & r3,
+ constant int & idx,
+ device const char * src00,
+ device const char * src01,
+ device const char * src02,
+ device const char * src03,
+ device const char * src04,
+ device const char * src05,
+ device const char * src06,
+ device const char * src07,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiitg[[thread_index_in_threadgroup]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+ const int64_t bid = tgpig.z/(ne12*ne13);
+
+ tgpig.z = tgpig.z%(ne12*ne13);
+
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+ kernel_mul_mv_q5_K_f32_impl(
+ src0[id],
+ (device const float *) (src1 + bid*nb11),
+ dst + bid*ne0,
+ ne00,
+ ne01,
+ ne02,
+ ne10,
+ ne12,
+ ne0,
+ ne1,
+ r2,
+ r3,
+ tgpig,
+ tiisg,
+ sgitg);
+}
+
+[[host_name("kernel_mul_mv_id_q6_K_f32")]]
+kernel void kernel_mul_mv_id_q6_K_f32(
+ device const char * ids,
+ device const char * src1,
+ device float * dst,
+ constant uint64_t & nbi1,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint64_t & nb1,
+ constant uint & r2,
+ constant uint & r3,
+ constant int & idx,
+ device const char * src00,
+ device const char * src01,
+ device const char * src02,
+ device const char * src03,
+ device const char * src04,
+ device const char * src05,
+ device const char * src06,
+ device const char * src07,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiitg[[thread_index_in_threadgroup]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+ const int64_t bid = tgpig.z/(ne12*ne13);
+
+ tgpig.z = tgpig.z%(ne12*ne13);
+
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+ kernel_mul_mv_q6_K_f32_impl(
+ src0[id],
+ (device const float *) (src1 + bid*nb11),
+ dst + bid*ne0,
+ ne00,
+ ne01,
+ ne02,
+ ne10,
+ ne12,
+ ne0,
+ ne1,
+ r2,
+ r3,
+ tgpig,
+ tiisg,
+ sgitg);
+}
diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs
index 87f8ac45..787a7d45 100644
--- a/candle-metal-kernels/src/tests.rs
+++ b/candle-metal-kernels/src/tests.rs
@@ -37,8 +37,7 @@ fn approx_bf16(v: Vec<bf16>, digits: i32) -> Vec<f32> {
fn run<T: Clone>(v: &[T], name: unary::contiguous::Kernel) -> Vec<T> {
let device = device();
- let fence = device.new_fence();
- let kernels = Kernels::new(fence);
+ let kernels = Kernels::new();
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let input = new_buffer(&device, v);
@@ -60,8 +59,7 @@ fn run<T: Clone>(v: &[T], name: unary::contiguous::Kernel) -> Vec<T> {
fn run_binary<T: Clone>(x: &[T], y: &[T], name: binary::contiguous::Kernel) -> Vec<T> {
let device = device();
- let fence = device.new_fence();
- let kernels = Kernels::new(fence);
+ let kernels = Kernels::new();
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let options = MTLResourceOptions::StorageModeManaged;
@@ -96,8 +94,7 @@ fn run_strided<T: Clone>(
let command_buffer = command_queue.new_command_buffer();
let input = new_buffer(&device, v);
let output = new_buffer(&device, v);
- let fence = device.new_fence();
- let kernels = Kernels::new(fence);
+ let kernels = Kernels::new();
call_unary_strided(
&device,
command_buffer,
@@ -278,8 +275,7 @@ fn binary_ops_bf16() {
fn cast<T: Clone, U: Clone>(v: &[T], name: &'static str) -> Vec<U> {
let device = device();
- let fence = device.new_fence();
- let kernels = Kernels::new(fence);
+ let kernels = Kernels::new();
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let input = new_buffer(&device, v);
@@ -409,8 +405,7 @@ fn it_cast_f16_bf16() {
fn run_affine<T: Clone>(v: &[T], mul: f64, add: f64) -> Vec<T> {
let device = device();
- let fence = device.new_fence();
- let kernels = Kernels::new(fence);
+ let kernels = Kernels::new();
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
@@ -445,8 +440,7 @@ fn run_affine_strided<T: Clone>(
add: f64,
) -> Vec<T> {
let device = device();
- let fence = device.new_fence();
- let kernels = Kernels::new(fence);
+ let kernels = Kernels::new();
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
@@ -595,8 +589,7 @@ fn run_index_select<T: Clone, I: Clone + std::fmt::Debug>(
let dst_el = ids.len() * left_size * right_size;
let dst_buffer = new_buffer(&device, &vec![0.0f32; dst_el]);
- let fence = device.new_fence();
- let kernels = Kernels::new(fence);
+ let kernels = Kernels::new();
call_index_select(
&device,
&command_buffer,
@@ -631,8 +624,7 @@ fn cos_f16() {
fn run_reduce<T: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<T> {
let device = device();
- let fence = device.new_fence();
- let kernels = Kernels::new(fence);
+ let kernels = Kernels::new();
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let input = new_buffer(&device, v);
@@ -662,8 +654,7 @@ fn run_reduce<T: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<T
fn run_softmax<T: Clone + std::fmt::Debug>(v: &[T], last_dim: usize, name: &'static str) -> Vec<T> {
let device = device();
- let fence = device.new_fence();
- let kernels = Kernels::new(fence);
+ let kernels = Kernels::new();
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let input = new_buffer(&device, v);
@@ -782,8 +773,7 @@ fn run_where_cond<I: Clone, T: Clone>(
name: &'static str,
) -> Vec<T> {
let device = device();
- let fence = device.new_fence();
- let kernels = Kernels::new(fence);
+ let kernels = Kernels::new();
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let options = MTLResourceOptions::StorageModeManaged;
@@ -859,8 +849,7 @@ fn run_gemm<T: Clone>(
rhs_offset: usize,
) -> Vec<T> {
let device = device();
- let fence = device.new_fence();
- let kernels = Kernels::new(fence);
+ let kernels = Kernels::new();
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let options = MTLResourceOptions::StorageModeManaged;
diff --git a/candle-metal-kernels/src/unary.metal b/candle-metal-kernels/src/unary.metal
index dcf803d8..7add58fd 100644
--- a/candle-metal-kernels/src/unary.metal
+++ b/candle-metal-kernels/src/unary.metal
@@ -117,7 +117,6 @@ UNARY_OP(erf)
UNARY_OP(tanh)
UNARY_OP(recip)
UNARY_OP(relu)
-
UNARY(id, float, copy_f32, copy_f32_strided)
UNARY(id, half, copy_f16, copy_f16_strided)
UNARY(id, uint8_t, copy_u8, copy_u8_strided)
@@ -136,6 +135,7 @@ BFLOAT_UNARY_OP(neg)
BFLOAT_UNARY_OP(exp)
BFLOAT_UNARY_OP(log)
BFLOAT_UNARY_OP(gelu)
+BFLOAT_UNARY_OP(abs)
BFLOAT_UNARY_OP(ceil)
BFLOAT_UNARY_OP(floor)
BFLOAT_UNARY_OP(round)
diff --git a/candle-nn/examples/cpu_benchmarks.rs b/candle-nn/examples/cpu_benchmarks.rs
index 68d384a6..001be116 100644
--- a/candle-nn/examples/cpu_benchmarks.rs
+++ b/candle-nn/examples/cpu_benchmarks.rs
@@ -222,7 +222,10 @@ impl Benchmark for QMatMul {
type RunResult = Tensor;
fn preprocess() -> Result<Self::PreProcessData> {
let zeros = vec![candle::quantized::k_quants::BlockQ4_0::zeros(); 4096 * 11008 / 32];
- let mm = candle::quantized::QTensor::new(zeros, (4096, 11008))?;
+ let mm = candle::quantized::QTensor::new(
+ candle::quantized::QStorage::Cpu(Box::new(zeros)),
+ (4096, 11008),
+ )?;
let mm = candle::quantized::QMatMul::from_qtensor(mm)?;
let arg = Tensor::randn(0f32, 1., (128, 11008), &Device::Cpu)?;
Ok((mm, arg))
diff --git a/candle-pyo3/py_src/candle/utils/__init__.pyi b/candle-pyo3/py_src/candle/utils/__init__.pyi
index 4ee51c29..c9a9f9f3 100644
--- a/candle-pyo3/py_src/candle/utils/__init__.pyi
+++ b/candle-pyo3/py_src/candle/utils/__init__.pyi
@@ -33,7 +33,9 @@ def has_mkl() -> bool:
pass
@staticmethod
-def load_ggml(path: Union[str, PathLike]) -> Tuple[Dict[str, QTensor], Dict[str, Any], List[str]]:
+def load_ggml(
+ path: Union[str, PathLike], device: Optional[Device] = None
+) -> Tuple[Dict[str, QTensor], Dict[str, Any], List[str]]:
"""
Load a GGML file. Returns a tuple of three objects: a dictionary mapping tensor names to tensors,
a dictionary mapping hyperparameter names to hyperparameter values, and a vocabulary.
@@ -41,7 +43,9 @@ def load_ggml(path: Union[str, PathLike]) -> Tuple[Dict[str, QTensor], Dict[str,
pass
@staticmethod
-def load_gguf(path: Union[str, PathLike]) -> Tuple[Dict[str, QTensor], Dict[str, Any]]:
+def load_gguf(
+ path: Union[str, PathLike], device: Optional[Device] = None
+) -> Tuple[Dict[str, QTensor], Dict[str, Any]]:
"""
Loads a GGUF file. Returns a tuple of two dictionaries: the first maps tensor names to tensors,
and the second maps metadata keys to metadata values.
diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs
index 90826b98..ca406876 100644
--- a/candle-pyo3/src/lib.rs
+++ b/candle-pyo3/src/lib.rs
@@ -1074,20 +1074,20 @@ impl PyTensor {
fn quantize(&self, quantized_dtype: &str) -> PyResult<PyQTensor> {
use ::candle::quantized;
let res = match quantized_dtype.to_lowercase().as_str() {
- "q2k" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ2K>(self),
- "q3k" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ3K>(self),
- "q4_0" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ4_0>(self),
- "q4_1" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ4_1>(self),
- "q4k" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ4K>(self),
- "q5_0" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ5_0>(self),
- "q5_1" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ5_1>(self),
- "q5k" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ5K>(self),
- "q6k" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ6K>(self),
- "q8_0" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ8_0>(self),
- "q8_1" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ8_1>(self),
- "q8k" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ8K>(self),
- "f16" => quantized::QTensor::quantize::<f16>(self),
- "f32" => quantized::QTensor::quantize::<f32>(self),
+ "q2k" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q2K),
+ "q3k" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q3K),
+ "q4_0" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q4_0),
+ "q4_1" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q4_1),
+ "q4k" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q4K),
+ "q5_0" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q5_0),
+ "q5_1" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q5_1),
+ "q5k" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q5K),
+ "q6k" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q6K),
+ "q8_0" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q8_0),
+ "q8_1" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q8_1),
+ "q8k" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q8K),
+ "f16" => quantized::QTensor::quantize(self, quantized::GgmlDType::F16),
+ "f32" => quantized::QTensor::quantize(self, quantized::GgmlDType::F32),
dt => {
return Err(PyErr::new::<PyValueError, _>(format!(
"unknown quantized-dtype {dt}"
@@ -1278,13 +1278,19 @@ fn save_safetensors(
}
#[pyfunction]
-#[pyo3(text_signature = "(path:Union[str,PathLike])")]
+#[pyo3(text_signature = "(path:Union[str,PathLike], device: Optional[Device] = None)")]
/// Load a GGML file. Returns a tuple of three objects: a dictionary mapping tensor names to tensors,
/// a dictionary mapping hyperparameter names to hyperparameter values, and a vocabulary.
/// &RETURNS&: Tuple[Dict[str,QTensor], Dict[str,Any], List[str]]
-fn load_ggml(path: &str, py: Python<'_>) -> PyResult<(PyObject, PyObject, PyObject)> {
+fn load_ggml(
+ path: &str,
+ device: Option<PyDevice>,
+ py: Python<'_>,
+) -> PyResult<(PyObject, PyObject, PyObject)> {
let mut file = std::fs::File::open(path)?;
- let ggml = ::candle::quantized::ggml_file::Content::read(&mut file).map_err(wrap_err)?;
+ let device = device.unwrap_or(PyDevice::Cpu).as_device()?;
+ let ggml =
+ ::candle::quantized::ggml_file::Content::read(&mut file, &device).map_err(wrap_err)?;
let tensors = ggml
.tensors
.into_iter()
@@ -1313,11 +1319,16 @@ fn load_ggml(path: &str, py: Python<'_>) -> PyResult<(PyObject, PyObject, PyObje
}
#[pyfunction]
-#[pyo3(text_signature = "(path:Union[str,PathLike])")]
+#[pyo3(text_signature = "(path:Union[str,PathLike], device: Optional[Device] = None)")]
/// Loads a GGUF file. Returns a tuple of two dictionaries: the first maps tensor names to tensors,
/// and the second maps metadata keys to metadata values.
/// &RETURNS&: Tuple[Dict[str,QTensor], Dict[str,Any]]
-fn load_gguf(path: &str, py: Python<'_>) -> PyResult<(PyObject, PyObject)> {
+fn load_gguf(
+ path: &str,
+ device: Option<PyDevice>,
+ py: Python<'_>,
+) -> PyResult<(PyObject, PyObject)> {
+ let device = device.unwrap_or(PyDevice::Cpu).as_device()?;
use ::candle::quantized::gguf_file;
fn gguf_value_to_pyobject(v: &gguf_file::Value, py: Python<'_>) -> PyResult<PyObject> {
let v: PyObject = match v {
@@ -1349,7 +1360,7 @@ fn load_gguf(path: &str, py: Python<'_>) -> PyResult<(PyObject, PyObject)> {
.tensor_infos
.keys()
.map(|key| {
- let qtensor = gguf.tensor(&mut file, key)?;
+ let qtensor = gguf.tensor(&mut file, key, &device)?;
Ok((key, PyQTensor(Arc::new(qtensor)).into_py(py)))
})
.collect::<::candle::Result<Vec<_>>>()
diff --git a/candle-transformers/src/models/quantized_llama.rs b/candle-transformers/src/models/quantized_llama.rs
index 1fb2d9e2..8aa06088 100644
--- a/candle-transformers/src/models/quantized_llama.rs
+++ b/candle-transformers/src/models/quantized_llama.rs
@@ -356,6 +356,7 @@ impl ModelWeights {
pub fn from_gguf<R: std::io::Seek + std::io::Read>(
ct: gguf_file::Content,
reader: &mut R,
+ device: &Device,
) -> Result<Self> {
let cpu = &Device::Cpu;
let md_get = |s: &str| match ct.metadata.get(s) {
@@ -383,21 +384,28 @@ impl ModelWeights {
.unwrap_or(10000f32);
let (cos, sin) = precomput_freqs_cis(rope_dim, rope_freq_base)?;
- let tok_embeddings = ct.tensor(reader, "token_embd.weight")?;
+ let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?;
let tok_embeddings = tok_embeddings.dequantize(cpu)?;
- let norm = RmsNorm::new(ct.tensor(reader, "output_norm.weight")?, rms_norm_eps)?;
- let output = ct.tensor(reader, "output.weight")?;
+ let norm = RmsNorm::new(
+ ct.tensor(reader, "output_norm.weight", device)?,
+ rms_norm_eps,
+ )?;
+ let output = ct.tensor(reader, "output.weight", device)?;
let mut layers = Vec::with_capacity(block_count);
for layer_idx in 0..block_count {
let prefix = format!("blk.{layer_idx}");
- let attention_wq = ct.tensor(reader, &format!("{prefix}.attn_q.weight"))?;
- let attention_wk = ct.tensor(reader, &format!("{prefix}.attn_k.weight"))?;
- let attention_wv = ct.tensor(reader, &format!("{prefix}.attn_v.weight"))?;
- let attention_wo = ct.tensor(reader, &format!("{prefix}.attn_output.weight"))?;
+ let attention_wq = ct.tensor(reader, &format!("{prefix}.attn_q.weight"), device)?;
+ let attention_wk = ct.tensor(reader, &format!("{prefix}.attn_k.weight"), device)?;
+ let attention_wv = ct.tensor(reader, &format!("{prefix}.attn_v.weight"), device)?;
+ let attention_wo =
+ ct.tensor(reader, &format!("{prefix}.attn_output.weight"), device)?;
let mlp_or_moe = if n_expert <= 1 {
- let feed_forward_w1 = ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"))?;
- let feed_forward_w2 = ct.tensor(reader, &format!("{prefix}.ffn_down.weight"))?;
- let feed_forward_w3 = ct.tensor(reader, &format!("{prefix}.ffn_up.weight"))?;
+ let feed_forward_w1 =
+ ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"), device)?;
+ let feed_forward_w2 =
+ ct.tensor(reader, &format!("{prefix}.ffn_down.weight"), device)?;
+ let feed_forward_w3 =
+ ct.tensor(reader, &format!("{prefix}.ffn_up.weight"), device)?;
MlpOrMoe::Mlp(Mlp {
feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?,
feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?,
@@ -405,15 +413,15 @@ impl ModelWeights {
})
} else {
let feed_forward_gate_inp =
- ct.tensor(reader, &format!("{prefix}.ffn_gate_inp.weight"))?;
+ ct.tensor(reader, &format!("{prefix}.ffn_gate_inp.weight"), device)?;
let mut experts = Vec::with_capacity(n_expert);
for i in 0..n_expert {
let feed_forward_w1 =
- ct.tensor(reader, &format!("{prefix}.ffn_gate.{i}.weight"))?;
+ ct.tensor(reader, &format!("{prefix}.ffn_gate.{i}.weight"), device)?;
let feed_forward_w2 =
- ct.tensor(reader, &format!("{prefix}.ffn_down.{i}.weight"))?;
+ ct.tensor(reader, &format!("{prefix}.ffn_down.{i}.weight"), device)?;
let feed_forward_w3 =
- ct.tensor(reader, &format!("{prefix}.ffn_up.{i}.weight"))?;
+ ct.tensor(reader, &format!("{prefix}.ffn_up.{i}.weight"), device)?;
experts.push(Mlp {
feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?,
feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?,
@@ -426,8 +434,9 @@ impl ModelWeights {
experts,
}
};
- let attention_norm = ct.tensor(reader, &format!("{prefix}.attn_norm.weight"))?;
- let ffn_norm = ct.tensor(reader, &format!("{prefix}.ffn_norm.weight"))?;
+ let attention_norm =
+ ct.tensor(reader, &format!("{prefix}.attn_norm.weight"), device)?;
+ let ffn_norm = ct.tensor(reader, &format!("{prefix}.ffn_norm.weight"), device)?;
let span_attn = tracing::span!(tracing::Level::TRACE, "attn");
let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot");
let span_mlp = tracing::span!(tracing::Level::TRACE, "attn-mlp");
diff --git a/candle-transformers/src/models/quantized_mixformer.rs b/candle-transformers/src/models/quantized_mixformer.rs
index 1a3cd4ac..882f4cf8 100644
--- a/candle-transformers/src/models/quantized_mixformer.rs
+++ b/candle-transformers/src/models/quantized_mixformer.rs
@@ -311,7 +311,7 @@ impl MixFormerSequentialForCausalLM {
let mut blocks = Vec::new();
for i in 0..cfg.n_layer {
let block = ParallelBlock::new(cfg, vb.pp(i + 1))?;
- blocks.push(block)
+ blocks.push(block);
}
let head = CausalLMHead::new(cfg, vb.pp(cfg.n_layer + 1))?;
Ok(Self {
@@ -332,7 +332,7 @@ impl MixFormerSequentialForCausalLM {
Some(get_mask(seq_len, xs.device())?)
};
for block in self.blocks.iter_mut() {
- xs = block.forward(&xs, mask.as_ref())?
+ xs = block.forward(&xs, mask.as_ref())?;
}
xs.narrow(1, seq_len - 1, 1)?.apply(&self.head)?.squeeze(1)
}
diff --git a/candle-transformers/src/quantized_var_builder.rs b/candle-transformers/src/quantized_var_builder.rs
index 63101f4c..bfd0629f 100644
--- a/candle-transformers/src/quantized_var_builder.rs
+++ b/candle-transformers/src/quantized_var_builder.rs
@@ -10,33 +10,33 @@ pub struct VarBuilder {
}
impl VarBuilder {
- pub fn from_gguf<P: AsRef<std::path::Path>>(p: P) -> Result<Self> {
+ pub fn from_gguf<P: AsRef<std::path::Path>>(p: P, device: &Device) -> Result<Self> {
let mut file = std::fs::File::open(p)?;
let content = candle::quantized::gguf_file::Content::read(&mut file)?;
let mut data = std::collections::HashMap::new();
for tensor_name in content.tensor_infos.keys() {
- let tensor = content.tensor(&mut file, tensor_name)?;
+ let tensor = content.tensor(&mut file, tensor_name, device)?;
data.insert(tensor_name.to_string(), Arc::new(tensor));
}
Ok(Self {
data: Arc::new(data),
path: Vec::new(),
- device: Device::Cpu,
+ device: device.clone(),
})
}
- pub fn from_gguf_buffer(buffer: &[u8]) -> Result<Self> {
+ pub fn from_gguf_buffer(buffer: &[u8], device: &Device) -> Result<Self> {
let mut cursor = std::io::Cursor::new(buffer);
let content = candle::quantized::gguf_file::Content::read(&mut cursor)?;
let mut data = std::collections::HashMap::new();
for tensor_name in content.tensor_infos.keys() {
- let tensor = content.tensor(&mut cursor, tensor_name)?;
+ let tensor = content.tensor(&mut cursor, tensor_name, device)?;
data.insert(tensor_name.to_string(), Arc::new(tensor));
}
Ok(Self {
data: Arc::new(data),
path: Vec::new(),
- device: Device::Cpu,
+ device: device.clone(),
})
}
diff --git a/candle-wasm-examples/blip/src/bin/m.rs b/candle-wasm-examples/blip/src/bin/m.rs
index 660bb717..e2ba4fed 100644
--- a/candle-wasm-examples/blip/src/bin/m.rs
+++ b/candle-wasm-examples/blip/src/bin/m.rs
@@ -61,7 +61,7 @@ impl Model {
let start = Date::now();
let model: SelectedModel = if quantized {
- let vb = quantized_blip::VarBuilder::from_gguf_buffer(&weights)?;
+ let vb = quantized_blip::VarBuilder::from_gguf_buffer(&weights, &device)?;
let model = quantized_blip::BlipForConditionalGeneration::new(&config, vb)?;
SelectedModel::Q(model)
} else {
diff --git a/candle-wasm-examples/phi/src/bin/m.rs b/candle-wasm-examples/phi/src/bin/m.rs
index 999f276d..859e58cb 100644
--- a/candle-wasm-examples/phi/src/bin/m.rs
+++ b/candle-wasm-examples/phi/src/bin/m.rs
@@ -41,6 +41,7 @@ impl Model {
) -> Result<Model, JsError> {
console_error_panic_hook::set_once();
console_log!("loading model");
+ let device = Device::Cpu;
let name: ModelName = serde_json::from_slice(&config)?;
let config: Config = serde_json::from_slice(&config)?;
@@ -50,8 +51,9 @@ impl Model {
let start = Date::now();
console_log!("weights len: {:?}", weights.len());
let model = if quantized {
- let vb =
- candle_transformers::quantized_var_builder::VarBuilder::from_gguf_buffer(&weights)?;
+ let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf_buffer(
+ &weights, &device,
+ )?;
console_log!("weights loaded");
if name._name_or_path == "microsoft/phi-2" {
let model = QMixFormer::new_v2(&config, vb)?;
diff --git a/candle-wasm-examples/t5/src/bin/m-quantized.rs b/candle-wasm-examples/t5/src/bin/m-quantized.rs
index 2f490b84..3b99a275 100644
--- a/candle-wasm-examples/t5/src/bin/m-quantized.rs
+++ b/candle-wasm-examples/t5/src/bin/m-quantized.rs
@@ -7,6 +7,7 @@ pub use candle_transformers::models::quantized_t5::{
use candle_wasm_example_t5::console_log;
use tokenizers::Tokenizer;
use wasm_bindgen::prelude::*;
+const DEVICE: Device = Device::Cpu;
#[wasm_bindgen]
pub struct ModelEncoder {
@@ -31,7 +32,7 @@ impl ModelConditionalGeneration {
) -> Result<ModelConditionalGeneration, JsError> {
console_error_panic_hook::set_once();
console_log!("loading model");
- let vb = VarBuilder::from_gguf_buffer(&weights)?;
+ let vb = VarBuilder::from_gguf_buffer(&weights, &DEVICE)?;
let mut config: Config = serde_json::from_slice(&config)?;
let tokenizer =
Tokenizer::from_bytes(&tokenizer).map_err(|m| JsError::new(&m.to_string()))?;
@@ -46,7 +47,7 @@ impl ModelConditionalGeneration {
pub fn decode(&mut self, input: JsValue) -> Result<JsValue, JsError> {
let input: ConditionalGenerationParams =
serde_wasm_bindgen::from_value(input).map_err(|m| JsError::new(&m.to_string()))?;
- let device = &Device::Cpu;
+ let device = &DEVICE;
self.model.clear_kv_cache();
let mut output_token_ids = [self.config.pad_token_id as u32].to_vec();
let prompt = input.prompt;
@@ -128,7 +129,7 @@ impl ModelEncoder {
) -> Result<ModelEncoder, JsError> {
console_error_panic_hook::set_once();
console_log!("loading model");
- let vb = VarBuilder::from_gguf_buffer(&weights)?;
+ let vb = VarBuilder::from_gguf_buffer(&weights, &DEVICE)?;
let mut config: Config = serde_json::from_slice(&config)?;
config.use_cache = false;
let tokenizer =
@@ -138,7 +139,7 @@ impl ModelEncoder {
}
pub fn decode(&mut self, input: JsValue) -> Result<JsValue, JsError> {
- let device = &Device::Cpu;
+ let device = &DEVICE;
let input: DecoderParams =
serde_wasm_bindgen::from_value(input).map_err(|m| JsError::new(&m.to_string()))?;
diff --git a/candle-wasm-examples/whisper/src/worker.rs b/candle-wasm-examples/whisper/src/worker.rs
index fd91fa8c..898996a7 100644
--- a/candle-wasm-examples/whisper/src/worker.rs
+++ b/candle-wasm-examples/whisper/src/worker.rs
@@ -315,6 +315,7 @@ impl Decoder {
let model = if md.quantized {
let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf_buffer(
&md.weights,
+ &device,
)?;
Model::Quantized(m::quantized_model::Whisper::load(&vb, config)?)
} else {
diff --git a/candle-wasm-tests/tests/quantized_tests.rs b/candle-wasm-tests/tests/quantized_tests.rs
index e5fa7dec..fc107e61 100644
--- a/candle-wasm-tests/tests/quantized_tests.rs
+++ b/candle-wasm-tests/tests/quantized_tests.rs
@@ -40,7 +40,7 @@ fn quantized_matmul_neg() -> Result<()> {
]
);
- let qtensor = quantized::QTensor::new(rhs_t, (4, 64))?;
+ let qtensor = quantized::QTensor::new(quantized::QStorage::Cpu(Box::new(rhs_t)), (4, 64))?;
let matmul = quantized::QMatMul::from_qtensor(qtensor)?;
let res = matmul.forward(&tensor_lhs)?;
assert_eq!(