summaryrefslogtreecommitdiff
path: root/candle-metal-kernels
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2024-04-20 18:55:45 +0200
committerGitHub <noreply@github.com>2024-04-20 18:55:45 +0200
commitdd78422701e9c6f3ca74218e8aedcf032c6c7215 (patch)
tree72c370795f8cd779c863275aaad3a73a66739fff /candle-metal-kernels
parent9215e9ce8c3fbe2e2850065557fc7e37b8e1c948 (diff)
downloadcandle-dd78422701e9c6f3ca74218e8aedcf032c6c7215.tar.gz
candle-dd78422701e9c6f3ca74218e8aedcf032c6c7215.tar.bz2
candle-dd78422701e9c6f3ca74218e8aedcf032c6c7215.zip
Handle multiple dimensions in metal QMM + two fixes. (#2097)
Diffstat (limited to 'candle-metal-kernels')
-rw-r--r--candle-metal-kernels/src/lib.rs15
1 files changed, 8 insertions, 7 deletions
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index 78108127..e05797a2 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -1699,7 +1699,7 @@ pub enum GgmlDType {
}
#[allow(clippy::too_many_arguments)]
-pub fn call_quantized_matmul_t(
+pub fn call_quantized_matmul_mv_t(
device: &Device,
command_buffer: &CommandBufferRef,
kernels: &Kernels,
@@ -1708,7 +1708,8 @@ pub fn call_quantized_matmul_t(
lhs: &Buffer,
lhs_offset: usize,
rhs: &Buffer,
- output: &Buffer,
+ dst_offset: usize,
+ dst: &Buffer,
) -> Result<(), MetalKernelError> {
// Everything is in reverse
let ne00 = k as i64;
@@ -1748,8 +1749,9 @@ pub fn call_quantized_matmul_t(
}
GgmlDType::Q2K => {
// Fixing a bug in Metal for GGML
- let nth0 = 4;
- let nth1 = 8;
+ // https://github.com/ggerganov/llama.cpp/blob/b8109bc0139f15a5b321909f47510b89dca47ffc/ggml-metal.m#L1576
+ let nth0 = 2;
+ let nth1 = 32;
let align = 4;
(nth0, nth1, align)
}
@@ -1821,7 +1823,7 @@ pub fn call_quantized_matmul_t(
(
rhs,
(lhs, lhs_offset),
- output,
+ (dst, dst_offset),
ne00,
ne01,
ne02,
@@ -1840,10 +1842,9 @@ pub fn call_quantized_matmul_t(
r3
)
);
- encoder.set_threadgroup_memory_length(0, 8192);
encoder.use_resource(lhs, metal::MTLResourceUsage::Read);
encoder.use_resource(rhs, metal::MTLResourceUsage::Read);
- encoder.use_resource(output, metal::MTLResourceUsage::Write);
+ encoder.use_resource(dst, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_groups_count, threads_per_threadgroup);
encoder.end_encoding();