use metal::{Buffer, ComputeCommandEncoderRef, ComputePipelineState, MTLSize}; use std::ffi::c_void; /// Most kernels apply similarly across the tensors /// This creates a strategy that uses the maximum amount of threads per threadgroup (capped at the /// actual total buffer length). /// Then kernels can just do their op on their single point in the buffer. pub(crate) fn linear_split(pipeline: &ComputePipelineState, length: usize) -> (MTLSize, MTLSize) { let size = length as u64; let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), size); let count = (size + width - 1) / width; let thread_group_count = MTLSize { width: count, height: 1, depth: 1, }; let thread_group_size = MTLSize { width, height: 1, depth: 1, }; (thread_group_count, thread_group_size) } // https://github.com/ml-explore/mlx/blob/bddf23f175726a57f0e443cd45518c0757daa166/mlx/backend/metal/utils.h#L96 pub(crate) fn get_block_dims(dim0: u64, dim1: u64, dim2: u64) -> MTLSize { let mut pows0 = 0u64; let mut pows1 = 0u64; let mut pows2 = 0u64; let mut sum = 0u64; loop { let presum = sum; // Check all the pows if dim0 >= (1 << (pows0 + 1)) { pows0 += 1; sum += 1; } if sum == 10 { break; } if dim1 >= (1 << (pows1 + 1)) { pows1 += 1; sum += 1; } if sum == 10 { break; } if dim2 >= (1 << (pows2 + 1)) { pows2 += 1; sum += 1; } if sum == presum || sum == 10 { break; } } MTLSize { width: 1 << pows0, height: 1 << pows1, depth: 1 << pows2, } } pub(crate) fn set_param( encoder: &ComputeCommandEncoderRef, position: u64, data: P, ) {

::set_param(encoder, position, data) } /// Helper functions to create the various objects on the compute command encoder /// on a single line. /// Prevents getting wrong some arguments number and mixing length and size in bytes. pub(crate) trait EncoderParam { fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self); } macro_rules! primitive { ($type:ty) => { impl EncoderParam for $type { fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) { encoder.set_bytes( position, core::mem::size_of::<$type>() as u64, &data as *const $type as *const c_void, ); } } }; } primitive!(bool); primitive!(usize); primitive!(i32); primitive!(i64); primitive!(u32); primitive!(u64); primitive!(f32); pub struct BufferOffset<'a> { pub buffer: &'a Buffer, pub offset_in_bytes: usize, } impl<'a> BufferOffset<'a> { pub fn zero_offset(buffer: &'a Buffer) -> Self { Self { buffer, offset_in_bytes: 0, } } } impl EncoderParam for &[T] { fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) { encoder.set_bytes( position, core::mem::size_of_val(data) as u64, data.as_ptr() as *const c_void, ); } } impl EncoderParam for &Buffer { fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) { encoder.set_buffer(position, Some(data), 0); } } impl EncoderParam for (&Buffer, usize) { fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) { encoder.set_buffer(position, Some(data.0), data.1 as u64); } } impl<'a> EncoderParam for &BufferOffset<'a> { fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) { encoder.set_buffer(position, Some(data.buffer), data.offset_in_bytes as u64); } } impl EncoderParam for &mut Buffer { fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) { encoder.set_buffer(position, Some(data), 0); } } impl EncoderParam for (&mut Buffer, usize) { fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) { encoder.set_buffer(position, Some(data.0), data.1 as u64); } } #[macro_export] macro_rules! set_params { ($encoder:ident, ($($param:expr),+)) => ( let mut _index = 0; $( $crate::utils::set_param($encoder, _index, $param); _index += 1; )* ); } pub trait EncoderProvider { type Encoder<'a>: AsRef where Self: 'a; fn encoder(&self) -> Self::Encoder<'_>; } pub struct WrappedEncoder<'a> { inner: &'a ComputeCommandEncoderRef, end_encoding_on_drop: bool, } impl<'a> Drop for WrappedEncoder<'a> { fn drop(&mut self) { if self.end_encoding_on_drop { self.inner.end_encoding() } } } impl<'a> AsRef for WrappedEncoder<'a> { fn as_ref(&self) -> &metal::ComputeCommandEncoderRef { self.inner } } impl EncoderProvider for &metal::CommandBuffer { type Encoder<'a> = WrappedEncoder<'a> where Self: 'a; fn encoder(&self) -> Self::Encoder<'_> { WrappedEncoder { inner: self.new_compute_command_encoder(), end_encoding_on_drop: true, } } } impl EncoderProvider for &metal::CommandBufferRef { type Encoder<'a> = WrappedEncoder<'a> where Self: 'a; fn encoder(&self) -> Self::Encoder<'_> { WrappedEncoder { inner: self.new_compute_command_encoder(), end_encoding_on_drop: true, } } } impl EncoderProvider for &ComputeCommandEncoderRef { type Encoder<'a> = WrappedEncoder<'a> where Self: 'a; fn encoder(&self) -> Self::Encoder<'_> { WrappedEncoder { inner: self, end_encoding_on_drop: false, } } }