diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-04-07 12:34:16 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-04-07 12:34:16 +0200 |
commit | 7f354473cf495db4554e08f84be44ed498f1aa5e (patch) | |
tree | 494e7edc590b754d17cd6da4608edcb24d9dd239 /candle-metal-kernels | |
parent | 33c9b6655459bd1086574cef9ba8f2e72a8804c8 (diff) | |
download | candle-7f354473cf495db4554e08f84be44ed498f1aa5e.tar.gz candle-7f354473cf495db4554e08f84be44ed498f1aa5e.tar.bz2 candle-7f354473cf495db4554e08f84be44ed498f1aa5e.zip |
Optimize copy-2d for metal. (#2024)
* Optimize copy-2d for metal.
* Add a hacky stopping rule for moondream.
Diffstat (limited to 'candle-metal-kernels')
-rw-r--r-- | candle-metal-kernels/src/lib.rs | 57 | ||||
-rw-r--r-- | candle-metal-kernels/src/unary.metal | 20 |
2 files changed, 57 insertions, 20 deletions
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index 4cff9bda..8b9be670 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -40,6 +40,44 @@ fn linear_split(pipeline: &ComputePipelineState, length: usize) -> (MTLSize, MTL (thread_group_count, thread_group_size) } +// https://github.com/ml-explore/mlx/blob/bddf23f175726a57f0e443cd45518c0757daa166/mlx/backend/metal/utils.h#L96 +fn get_block_dims(dim0: u64, dim1: u64, dim2: u64) -> MTLSize { + let mut pows0 = 0u64; + let mut pows1 = 0u64; + let mut pows2 = 0u64; + let mut sum = 0u64; + loop { + let presum = sum; + // Check all the pows + if dim0 >= (1 << (pows0 + 1)) { + pows0 += 1; + sum += 1; + } + if sum == 10 { + break; + } + if dim1 >= (1 << (pows1 + 1)) { + pows1 += 1; + sum += 1; + } + if sum == 10 { + break; + } + if dim2 >= (1 << (pows2 + 1)) { + pows2 += 1; + sum += 1; + } + if sum == presum || sum == 10 { + break; + } + } + MTLSize { + width: 1 << pows0, + height: 1 << pows1, + depth: 1 << pows2, + } +} + fn set_param<P: EncoderParam>(encoder: &ComputeCommandEncoderRef, position: u64, data: P) { <P as EncoderParam>::set_param(encoder, position, data) } @@ -396,21 +434,24 @@ pub fn call_copy2d( set_params!( encoder, ( - d1, - d2, - src_s, - dst_s, + d1 as i64, + d2 as i64, + src_s as i64, + dst_s as i64, (input, src_o_in_bytes), (output, dst_o_in_bytes) ) ); - let width: usize = d1 * d2; - let (thread_group_count, thread_group_size) = linear_split(&pipeline, width); - + let grid_dims = MTLSize { + width: d1 as u64, + height: d2 as u64, + depth: 1, + }; + let group_dims = get_block_dims(d1 as u64, d2 as u64, 1); encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); - encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.dispatch_threads(grid_dims, group_dims); encoder.end_encoding(); Ok(()) } diff --git a/candle-metal-kernels/src/unary.metal b/candle-metal-kernels/src/unary.metal index 809522d7..4b6363ed 100644 --- a/candle-metal-kernels/src/unary.metal +++ b/candle-metal-kernels/src/unary.metal @@ -104,21 +104,17 @@ UNARY(NAME, bfloat, NAME##_bf16, NAME##_bf16_strided); #define COPY2D(FN_NAME, TYPENAME) \ kernel void FN_NAME( \ - constant size_t &d1, \ - constant size_t &d2, \ - constant size_t &src_s, \ - constant size_t &dst_s, \ + constant int64_t &d1, \ + constant int64_t &d2, \ + constant int64_t &src_s, \ + constant int64_t &dst_s, \ device const TYPENAME *input, \ device TYPENAME *output, \ - uint tid [[ thread_position_in_grid ]] \ + uint2 idx [[thread_position_in_grid]] \ ) { \ - if (tid >= d1 * d2) { \ - return; \ - } \ - size_t idx1 = tid / d2; \ - size_t idx2 = tid - idx1 * d2; \ - size_t src_idx = idx1 * src_s + idx2; \ - size_t dst_idx = idx1 * dst_s + idx2; \ + if (idx.x >= d1 || idx.y >= d2) return; \ + int64_t src_idx = idx.x * src_s + idx.y; \ + int64_t dst_idx = idx.x * dst_s + idx.y; \ output[dst_idx] = input[src_idx]; \ } |