From 7f354473cf495db4554e08f84be44ed498f1aa5e Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Sun, 7 Apr 2024 12:34:16 +0200 Subject: Optimize copy-2d for metal. (#2024) * Optimize copy-2d for metal. * Add a hacky stopping rule for moondream. --- candle-metal-kernels/src/lib.rs | 57 +++++++++++++++++++++++++++++++----- candle-metal-kernels/src/unary.metal | 20 +++++-------- 2 files changed, 57 insertions(+), 20 deletions(-) (limited to 'candle-metal-kernels') diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index 4cff9bda..8b9be670 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -40,6 +40,44 @@ fn linear_split(pipeline: &ComputePipelineState, length: usize) -> (MTLSize, MTL (thread_group_count, thread_group_size) } +// https://github.com/ml-explore/mlx/blob/bddf23f175726a57f0e443cd45518c0757daa166/mlx/backend/metal/utils.h#L96 +fn get_block_dims(dim0: u64, dim1: u64, dim2: u64) -> MTLSize { + let mut pows0 = 0u64; + let mut pows1 = 0u64; + let mut pows2 = 0u64; + let mut sum = 0u64; + loop { + let presum = sum; + // Check all the pows + if dim0 >= (1 << (pows0 + 1)) { + pows0 += 1; + sum += 1; + } + if sum == 10 { + break; + } + if dim1 >= (1 << (pows1 + 1)) { + pows1 += 1; + sum += 1; + } + if sum == 10 { + break; + } + if dim2 >= (1 << (pows2 + 1)) { + pows2 += 1; + sum += 1; + } + if sum == presum || sum == 10 { + break; + } + } + MTLSize { + width: 1 << pows0, + height: 1 << pows1, + depth: 1 << pows2, + } +} + fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: P) {

::set_param(encoder, position, data) } @@ -396,21 +434,24 @@ pub fn call_copy2d( set_params!( encoder, ( - d1, - d2, - src_s, - dst_s, + d1 as i64, + d2 as i64, + src_s as i64, + dst_s as i64, (input, src_o_in_bytes), (output, dst_o_in_bytes) ) ); - let width: usize = d1 * d2; - let (thread_group_count, thread_group_size) = linear_split(&pipeline, width); - + let grid_dims = MTLSize { + width: d1 as u64, + height: d2 as u64, + depth: 1, + }; + let group_dims = get_block_dims(d1 as u64, d2 as u64, 1); encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); - encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.dispatch_threads(grid_dims, group_dims); encoder.end_encoding(); Ok(()) } diff --git a/candle-metal-kernels/src/unary.metal b/candle-metal-kernels/src/unary.metal index 809522d7..4b6363ed 100644 --- a/candle-metal-kernels/src/unary.metal +++ b/candle-metal-kernels/src/unary.metal @@ -104,21 +104,17 @@ UNARY(NAME, bfloat, NAME##_bf16, NAME##_bf16_strided); #define COPY2D(FN_NAME, TYPENAME) \ kernel void FN_NAME( \ - constant size_t &d1, \ - constant size_t &d2, \ - constant size_t &src_s, \ - constant size_t &dst_s, \ + constant int64_t &d1, \ + constant int64_t &d2, \ + constant int64_t &src_s, \ + constant int64_t &dst_s, \ device const TYPENAME *input, \ device TYPENAME *output, \ - uint tid [[ thread_position_in_grid ]] \ + uint2 idx [[thread_position_in_grid]] \ ) { \ - if (tid >= d1 * d2) { \ - return; \ - } \ - size_t idx1 = tid / d2; \ - size_t idx2 = tid - idx1 * d2; \ - size_t src_idx = idx1 * src_s + idx2; \ - size_t dst_idx = idx1 * dst_s + idx2; \ + if (idx.x >= d1 || idx.y >= d2) return; \ + int64_t src_idx = idx.x * src_s + idx.y; \ + int64_t dst_idx = idx.x * dst_s + idx.y; \ output[dst_idx] = input[src_idx]; \ } -- cgit v1.2.3