diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-04-20 18:55:45 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-04-20 18:55:45 +0200 |
commit | dd78422701e9c6f3ca74218e8aedcf032c6c7215 (patch) | |
tree | 72c370795f8cd779c863275aaad3a73a66739fff /candle-metal-kernels | |
parent | 9215e9ce8c3fbe2e2850065557fc7e37b8e1c948 (diff) | |
download | candle-dd78422701e9c6f3ca74218e8aedcf032c6c7215.tar.gz candle-dd78422701e9c6f3ca74218e8aedcf032c6c7215.tar.bz2 candle-dd78422701e9c6f3ca74218e8aedcf032c6c7215.zip |
Handle multiple dimensions in metal QMM + two fixes. (#2097)
Diffstat (limited to 'candle-metal-kernels')
-rw-r--r-- | candle-metal-kernels/src/lib.rs | 15 |
1 file changed, 8 insertions, 7 deletions
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index 78108127..e05797a2 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -1699,7 +1699,7 @@ pub enum GgmlDType { } #[allow(clippy::too_many_arguments)] -pub fn call_quantized_matmul_t( +pub fn call_quantized_matmul_mv_t( device: &Device, command_buffer: &CommandBufferRef, kernels: &Kernels, @@ -1708,7 +1708,8 @@ pub fn call_quantized_matmul_t( lhs: &Buffer, lhs_offset: usize, rhs: &Buffer, - output: &Buffer, + dst_offset: usize, + dst: &Buffer, ) -> Result<(), MetalKernelError> { // Everything is in reverse let ne00 = k as i64; @@ -1748,8 +1749,9 @@ pub fn call_quantized_matmul_t( } GgmlDType::Q2K => { // Fixing a bug in Metal for GGML - let nth0 = 4; - let nth1 = 8; + // https://github.com/ggerganov/llama.cpp/blob/b8109bc0139f15a5b321909f47510b89dca47ffc/ggml-metal.m#L1576 + let nth0 = 2; + let nth1 = 32; let align = 4; (nth0, nth1, align) } @@ -1821,7 +1823,7 @@ pub fn call_quantized_matmul_t( ( rhs, (lhs, lhs_offset), - output, + (dst, dst_offset), ne00, ne01, ne02, @@ -1840,10 +1842,9 @@ pub fn call_quantized_matmul_t( r3 ) ); - encoder.set_threadgroup_memory_length(0, 8192); encoder.use_resource(lhs, metal::MTLResourceUsage::Read); encoder.use_resource(rhs, metal::MTLResourceUsage::Read); - encoder.use_resource(output, metal::MTLResourceUsage::Write); + encoder.use_resource(dst, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_groups_count, threads_per_threadgroup); encoder.end_encoding(); |