summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-core/src/quantized/metal.rs35
-rw-r--r--candle-metal-kernels/src/lib.rs15
2 files changed, 28 insertions, 22 deletions
diff --git a/candle-core/src/quantized/metal.rs b/candle-core/src/quantized/metal.rs
index c310d766..f7f5b68a 100644
--- a/candle-core/src/quantized/metal.rs
+++ b/candle-core/src/quantized/metal.rs
@@ -152,9 +152,9 @@ impl QMetalStorage {
// We always use a single batch dimension and stack all the tensors in the batch on the
// second dimension as the implementation in candle-metal-kernels doesn't handle batch
// properly.
- let (b, m) = match dst_shape.len() {
- 3 => (1, dst_shape[0] * dst_shape[1]),
- 2 => (1, dst_shape[0]),
+ let m = match dst_shape.len() {
+ 3 => dst_shape[0] * dst_shape[1],
+ 2 => dst_shape[0],
n => crate::bail!("Invalid rank {n} for quantized matmul metal"),
};
let last_k = dst_shape.pop().unwrap();
@@ -166,18 +166,23 @@ impl QMetalStorage {
let device = storage.device().clone();
let dst = device.new_buffer(dst_shape.elem_count(), DType::F32, "qmatmul")?;
let command_buffer = device.command_buffer()?;
- candle_metal_kernels::call_quantized_matmul_t(
- device.device(),
- &command_buffer,
- device.kernels(),
- self.dtype.into(),
- (b, m, n, k),
- storage.buffer(),
- layout.start_offset() * storage.dtype().size_in_bytes(),
- &self.buffer,
- &dst,
- )
- .map_err(MetalError::from)?;
+ // In some cases it would be better to use the mm variant, though it has its drawbacks
+ // around memory alignemnt.
+ for batch_id in 0..m {
+ candle_metal_kernels::call_quantized_matmul_mv_t(
+ device.device(),
+ &command_buffer,
+ device.kernels(),
+ self.dtype.into(),
+ (1, 1, n, k),
+ storage.buffer(),
+ (layout.start_offset() + batch_id * k) * storage.dtype().size_in_bytes(),
+ &self.buffer,
+ batch_id * n * DType::F32.size_in_bytes(),
+ &dst,
+ )
+ .map_err(MetalError::from)?;
+ }
let dst_storage = crate::MetalStorage::new(dst, device, dst_shape.elem_count(), DType::F32);
Ok((dst_storage, dst_shape))
}
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index 78108127..e05797a2 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -1699,7 +1699,7 @@ pub enum GgmlDType {
}
#[allow(clippy::too_many_arguments)]
-pub fn call_quantized_matmul_t(
+pub fn call_quantized_matmul_mv_t(
device: &Device,
command_buffer: &CommandBufferRef,
kernels: &Kernels,
@@ -1708,7 +1708,8 @@ pub fn call_quantized_matmul_t(
lhs: &Buffer,
lhs_offset: usize,
rhs: &Buffer,
- output: &Buffer,
+ dst_offset: usize,
+ dst: &Buffer,
) -> Result<(), MetalKernelError> {
// Everything is in reverse
let ne00 = k as i64;
@@ -1748,8 +1749,9 @@ pub fn call_quantized_matmul_t(
}
GgmlDType::Q2K => {
// Fixing a bug in Metal for GGML
- let nth0 = 4;
- let nth1 = 8;
+ // https://github.com/ggerganov/llama.cpp/blob/b8109bc0139f15a5b321909f47510b89dca47ffc/ggml-metal.m#L1576
+ let nth0 = 2;
+ let nth1 = 32;
let align = 4;
(nth0, nth1, align)
}
@@ -1821,7 +1823,7 @@ pub fn call_quantized_matmul_t(
(
rhs,
(lhs, lhs_offset),
- output,
+ (dst, dst_offset),
ne00,
ne01,
ne02,
@@ -1840,10 +1842,9 @@ pub fn call_quantized_matmul_t(
r3
)
);
- encoder.set_threadgroup_memory_length(0, 8192);
encoder.use_resource(lhs, metal::MTLResourceUsage::Read);
encoder.use_resource(rhs, metal::MTLResourceUsage::Read);
- encoder.use_resource(output, metal::MTLResourceUsage::Write);
+ encoder.use_resource(dst, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_groups_count, threads_per_threadgroup);
encoder.end_encoding();