diff options
Diffstat (limited to 'candle-metal-kernels/src/lib.rs')
-rw-r--r-- | candle-metal-kernels/src/lib.rs | 57 |
1 files changed, 49 insertions, 8 deletions
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index 4cff9bda..8b9be670 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -40,6 +40,44 @@ fn linear_split(pipeline: &ComputePipelineState, length: usize) -> (MTLSize, MTL (thread_group_count, thread_group_size) } +// https://github.com/ml-explore/mlx/blob/bddf23f175726a57f0e443cd45518c0757daa166/mlx/backend/metal/utils.h#L96 +fn get_block_dims(dim0: u64, dim1: u64, dim2: u64) -> MTLSize { + let mut pows0 = 0u64; + let mut pows1 = 0u64; + let mut pows2 = 0u64; + let mut sum = 0u64; + loop { + let presum = sum; + // Check all the pows + if dim0 >= (1 << (pows0 + 1)) { + pows0 += 1; + sum += 1; + } + if sum == 10 { + break; + } + if dim1 >= (1 << (pows1 + 1)) { + pows1 += 1; + sum += 1; + } + if sum == 10 { + break; + } + if dim2 >= (1 << (pows2 + 1)) { + pows2 += 1; + sum += 1; + } + if sum == presum || sum == 10 { + break; + } + } + MTLSize { + width: 1 << pows0, + height: 1 << pows1, + depth: 1 << pows2, + } +} + fn set_param<P: EncoderParam>(encoder: &ComputeCommandEncoderRef, position: u64, data: P) { <P as EncoderParam>::set_param(encoder, position, data) } @@ -396,21 +434,24 @@ pub fn call_copy2d( set_params!( encoder, ( - d1, - d2, - src_s, - dst_s, + d1 as i64, + d2 as i64, + src_s as i64, + dst_s as i64, (input, src_o_in_bytes), (output, dst_o_in_bytes) ) ); - let width: usize = d1 * d2; - let (thread_group_count, thread_group_size) = linear_split(&pipeline, width); - + let grid_dims = MTLSize { + width: d1 as u64, + height: d2 as u64, + depth: 1, + }; + let group_dims = get_block_dims(d1 as u64, d2 as u64, 1); encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); - encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.dispatch_threads(grid_dims, group_dims); encoder.end_encoding(); Ok(()) } |