diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-04-07 22:37:53 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-04-07 22:37:53 +0200 |
commit | c5fe4a7f8983ae7c9641fa923f26ef60538aef06 (patch) | |
tree | 12ad3e2445577fc77a5f9879686d554aea943a0d /candle-metal-kernels | |
parent | 7f354473cf495db4554e08f84be44ed498f1aa5e (diff) | |
download | candle-c5fe4a7f8983ae7c9641fa923f26ef60538aef06.tar.gz candle-c5fe4a7f8983ae7c9641fa923f26ef60538aef06.tar.bz2 candle-c5fe4a7f8983ae7c9641fa923f26ef60538aef06.zip |
Rework the buffer offset logic for metal kernels (#2028)
* Move the metal kernels utils in a separate module.
* Use the BufferOffset for unary ops.
* Fix clippy lints.
* Use the new BufferOffset.
* Adapt the binary ops.
* Affine.
* More ops (powf, elu, cast).
Diffstat (limited to 'candle-metal-kernels')
-rw-r--r-- | candle-metal-kernels/src/lib.rs | 289 | ||||
-rw-r--r-- | candle-metal-kernels/src/tests.rs | 58 | ||||
-rw-r--r-- | candle-metal-kernels/src/utils.rs | 162 |
3 files changed, 262 insertions, 247 deletions
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index 8b9be670..23c072af 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -1,11 +1,15 @@ use metal::{ - Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineState, - Device, Function, FunctionConstantValues, Library, MTLDataType, MTLSize, NSUInteger, + Buffer, CommandBufferRef, CompileOptions, ComputePipelineState, Device, Function, + FunctionConstantValues, Library, MTLDataType, MTLSize, NSUInteger, }; use std::collections::HashMap; use std::ffi::c_void; use std::sync::RwLock; +mod utils; +pub use utils::BufferOffset; +use utils::{get_block_dims, linear_split}; + const AFFINE: &str = include_str!("affine.metal"); const INDEXING: &str = include_str!("indexing.metal"); const UNARY: &str = include_str!("unary.metal"); @@ -18,138 +22,6 @@ const RANDOM: &str = include_str!("random.metal"); const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib"); const QUANTIZED: &str = include_str!("quantized.metal"); -/// Most kernels apply similarly across the tensors -/// This creates a strategy that uses the maximum amount of threads per threadgroup (capped at the -/// actual total buffer length). -/// Then kernels can just do their op on their single point in the buffer. -fn linear_split(pipeline: &ComputePipelineState, length: usize) -> (MTLSize, MTLSize) { - let size = length as u64; - let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), size); - let count = (size + width - 1) / width; - let thread_group_count = MTLSize { - width: count, - height: 1, - depth: 1, - }; - - let thread_group_size = MTLSize { - width, - height: 1, - depth: 1, - }; - (thread_group_count, thread_group_size) -} - -// https://github.com/ml-explore/mlx/blob/bddf23f175726a57f0e443cd45518c0757daa166/mlx/backend/metal/utils.h#L96 -fn get_block_dims(dim0: u64, dim1: u64, dim2: u64) -> MTLSize { - let mut pows0 = 0u64; - let mut pows1 = 0u64; - let mut pows2 = 0u64; - let mut sum = 0u64; - loop { - let presum = sum; - // Check all the pows - if dim0 >= (1 << (pows0 + 1)) { - pows0 += 1; - sum += 1; - } - if sum == 10 { - break; - } - if dim1 >= (1 << (pows1 + 1)) { - pows1 += 1; - sum += 1; - } - if sum == 10 { - break; - } - if dim2 >= (1 << (pows2 + 1)) { - pows2 += 1; - sum += 1; - } - if sum == presum || sum == 10 { - break; - } - } - MTLSize { - width: 1 << pows0, - height: 1 << pows1, - depth: 1 << pows2, - } -} - -fn set_param<P: EncoderParam>(encoder: &ComputeCommandEncoderRef, position: u64, data: P) { - <P as EncoderParam>::set_param(encoder, position, data) -} - -/// Helper functions to create the various objects on the compute command encoder -/// on a single line. -/// Prevents getting wrong some arguments number and mixing length and size in bytes. -trait EncoderParam { - fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self); -} -macro_rules! primitive { - ($type:ty) => { - impl EncoderParam for $type { - fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) { - encoder.set_bytes( - position, - core::mem::size_of::<$type>() as u64, - &data as *const $type as *const c_void, - ); - } - } - }; -} -primitive!(bool); -primitive!(usize); -primitive!(i32); -primitive!(i64); -primitive!(u32); -primitive!(u64); -primitive!(f32); - -impl<T> EncoderParam for &[T] { - fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) { - encoder.set_bytes( - position, - core::mem::size_of_val(data) as u64, - data.as_ptr() as *const c_void, - ); - } -} - -impl EncoderParam for &Buffer { - fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) { - encoder.set_buffer(position, Some(data), 0); - } -} -impl EncoderParam for (&Buffer, usize) { - fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) { - encoder.set_buffer(position, Some(data.0), data.1 as u64); - } -} -impl EncoderParam for &mut Buffer { - fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) { - encoder.set_buffer(position, Some(data), 0); - } -} -impl EncoderParam for (&mut Buffer, usize) { - fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) { - encoder.set_buffer(position, Some(data.0), data.1 as u64); - } -} - -macro_rules! set_params { - ($encoder:ident, ($($param:expr),+)) => ( - let mut _index = 0; - $( - set_param($encoder, _index, $param); - _index += 1; - )* - ); -} - #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum Source { Affine, @@ -273,6 +145,12 @@ pub struct Kernels { pipelines: RwLock<Pipelines>, } +impl Default for Kernels { + fn default() -> Self { + Self::new() + } +} + impl Kernels { pub fn new() -> Self { let libraries = RwLock::new(Libraries::new()); @@ -396,17 +274,17 @@ pub fn call_unary_contiguous( kernels: &Kernels, kernel_name: unary::contiguous::Kernel, length: usize, - input: &Buffer, + input: BufferOffset, output: &Buffer, ) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?; let encoder = command_buffer.new_compute_command_encoder(); encoder.set_compute_pipeline_state(&pipeline); - set_params!(encoder, (length, input, output)); + set_params!(encoder, (length, &input, output)); let (thread_group_count, thread_group_size) = linear_split(&pipeline, length); - encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.end_encoding(); @@ -463,11 +341,9 @@ pub fn call_unary_strided( kernels: &Kernels, name: unary::strided::Kernel, shape: &[usize], - input: &Buffer, + input: BufferOffset, strides: &[usize], - offset: usize, - output: &Buffer, - output_offset: usize, + output: BufferOffset, ) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, Source::Unary, name.0)?; @@ -476,23 +352,13 @@ pub fn call_unary_strided( encoder.set_compute_pipeline_state(&pipeline); let length: usize = shape.iter().product(); - set_params!( - encoder, - ( - length, - num_dims, - shape, - strides, - (input, offset), - (output, output_offset) - ) - ); + set_params!(encoder, (length, num_dims, shape, strides, &input, &output)); let width: usize = shape.iter().product(); let (thread_group_count, thread_group_size) = linear_split(&pipeline, width); - encoder.use_resource(input, metal::MTLResourceUsage::Read); - encoder.use_resource(output, metal::MTLResourceUsage::Write); + encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); + encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.end_encoding(); Ok(()) @@ -505,8 +371,8 @@ pub fn call_binary_contiguous( kernels: &Kernels, kernel_name: binary::contiguous::Kernel, length: usize, - left: &Buffer, - right: &Buffer, + left: BufferOffset, + right: BufferOffset, output: &Buffer, ) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, Source::Binary, kernel_name.0)?; @@ -514,12 +380,12 @@ pub fn call_binary_contiguous( let encoder = command_buffer.new_compute_command_encoder(); encoder.set_compute_pipeline_state(&pipeline); - set_params!(encoder, (length, left, right, output)); + set_params!(encoder, (length, &left, &right, output)); let (thread_group_count, thread_group_size) = linear_split(&pipeline, length); - encoder.use_resource(left, metal::MTLResourceUsage::Read); - encoder.use_resource(right, metal::MTLResourceUsage::Read); + encoder.use_resource(left.buffer, metal::MTLResourceUsage::Read); + encoder.use_resource(right.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.end_encoding(); @@ -533,12 +399,10 @@ pub fn call_binary_strided( kernels: &Kernels, name: binary::strided::Kernel, shape: &[usize], - left_input: &Buffer, + left_input: BufferOffset, left_strides: &[usize], - left_offset: usize, - right_input: &Buffer, + right_input: BufferOffset, right_strides: &[usize], - right_offset: usize, output: &Buffer, ) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, Source::Binary, name.0)?; @@ -558,16 +422,16 @@ pub fn call_binary_strided( shape, left_strides, right_strides, - (left_input, left_offset), - (right_input, right_offset), + &left_input, + &right_input, output ) ); let (thread_group_count, thread_group_size) = linear_split(&pipeline, width); - encoder.use_resource(left_input, metal::MTLResourceUsage::Read); - encoder.use_resource(right_input, metal::MTLResourceUsage::Read); + encoder.use_resource(left_input.buffer, metal::MTLResourceUsage::Read); + encoder.use_resource(right_input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.end_encoding(); @@ -581,8 +445,7 @@ pub fn call_cast_contiguous( kernels: &Kernels, kernel_name: &'static str, length: usize, - input: &Buffer, - input_offset: usize, + input: BufferOffset, output: &Buffer, ) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?; @@ -590,10 +453,10 @@ pub fn call_cast_contiguous( let encoder = command_buffer.new_compute_command_encoder(); encoder.set_compute_pipeline_state(&pipeline); - set_params!(encoder, (length, (input, input_offset), output)); + set_params!(encoder, (length, &input, output)); let (thread_group_count, thread_group_size) = linear_split(&pipeline, length); - encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.end_encoding(); @@ -607,9 +470,8 @@ pub fn call_cast_strided( kernels: &Kernels, kernel_name: &'static str, shape: &[usize], - input: &Buffer, + input: BufferOffset, input_strides: &[usize], - input_offset: usize, output: &Buffer, ) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?; @@ -621,25 +483,19 @@ pub fn call_cast_strided( set_params!( encoder, - ( - length, - shape.len(), - shape, - input_strides, - (input, input_offset), - output - ) + (length, shape.len(), shape, input_strides, &input, output) ); let (thread_group_count, thread_group_size) = linear_split(&pipeline, length); - encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.end_encoding(); Ok(()) } +#[allow(clippy::too_many_arguments)] pub fn call_reduce_contiguous( device: &Device, command_buffer: &CommandBufferRef, @@ -687,6 +543,7 @@ pub fn call_reduce_contiguous( Ok(()) } +#[allow(clippy::too_many_arguments)] pub fn call_reduce_strided( device: &Device, command_buffer: &CommandBufferRef, @@ -985,7 +842,7 @@ pub fn call_affine( kernels: &Kernels, name: &'static str, size: usize, - input: &Buffer, + input: BufferOffset, output: &Buffer, mul: f32, add: f32, @@ -995,10 +852,10 @@ pub fn call_affine( let encoder = command_buffer.new_compute_command_encoder(); encoder.set_compute_pipeline_state(&pipeline); - set_params!(encoder, (size, mul, add, input, output)); + set_params!(encoder, (size, mul, add, &input, output)); let (thread_group_count, thread_group_size) = linear_split(&pipeline, size); - encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.end_encoding(); @@ -1012,9 +869,8 @@ pub fn call_affine_strided( kernels: &Kernels, name: &'static str, shape: &[usize], - input: &Buffer, + input: BufferOffset, input_stride: &[usize], - input_offset: usize, output: &Buffer, mul: f32, add: f32, @@ -1034,13 +890,13 @@ pub fn call_affine_strided( input_stride, mul, add, - (input, input_offset), + &input, output ) ); let (thread_group_count, thread_group_size) = linear_split(&pipeline, size); - encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.end_encoding(); @@ -1054,7 +910,7 @@ pub fn call_powf( kernels: &Kernels, name: &'static str, size: usize, - input: &Buffer, + input: BufferOffset, output: &Buffer, mul: f32, ) -> Result<(), MetalKernelError> { @@ -1063,10 +919,10 @@ pub fn call_powf( let encoder = command_buffer.new_compute_command_encoder(); encoder.set_compute_pipeline_state(&pipeline); - set_params!(encoder, (size, mul, input, output)); + set_params!(encoder, (size, mul, &input, output)); let (thread_group_count, thread_group_size) = linear_split(&pipeline, size); - encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.end_encoding(); @@ -1080,9 +936,8 @@ pub fn call_powf_strided( kernels: &Kernels, name: &'static str, shape: &[usize], - input: &Buffer, + input: BufferOffset, input_stride: &[usize], - input_offset: usize, output: &Buffer, mul: f32, ) -> Result<(), MetalKernelError> { @@ -1094,19 +949,11 @@ pub fn call_powf_strided( set_params!( encoder, - ( - size, - shape.len(), - shape, - input_stride, - mul, - (input, input_offset), - output - ) + (size, shape.len(), shape, input_stride, mul, &input, output) ); let (thread_group_count, thread_group_size) = linear_split(&pipeline, size); - encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.end_encoding(); @@ -1120,7 +967,7 @@ pub fn call_elu( kernels: &Kernels, name: &'static str, size: usize, - input: &Buffer, + input: BufferOffset, output: &Buffer, mul: f32, ) -> Result<(), MetalKernelError> { @@ -1129,10 +976,10 @@ pub fn call_elu( let encoder = command_buffer.new_compute_command_encoder(); encoder.set_compute_pipeline_state(&pipeline); - set_params!(encoder, (size, mul, input, output)); + set_params!(encoder, (size, mul, &input, output)); let (thread_group_count, thread_group_size) = linear_split(&pipeline, size); - encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.end_encoding(); @@ -1146,9 +993,8 @@ pub fn call_elu_strided( kernels: &Kernels, name: &'static str, shape: &[usize], - input: &Buffer, + input: BufferOffset, input_stride: &[usize], - input_offset: usize, output: &Buffer, mul: f32, ) -> Result<(), MetalKernelError> { @@ -1160,25 +1006,18 @@ pub fn call_elu_strided( set_params!( encoder, - ( - size, - shape.len(), - shape, - input_stride, - mul, - (input, input_offset), - output - ) + (size, shape.len(), shape, input_stride, mul, &input, output) ); let (thread_group_count, thread_group_size) = linear_split(&pipeline, size); - encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.end_encoding(); Ok(()) } +#[allow(clippy::too_many_arguments)] pub fn call_where_cond_strided( device: &Device, command_buffer: &CommandBufferRef, @@ -1334,6 +1173,7 @@ pub fn call_gather( Ok(()) } +#[allow(clippy::too_many_arguments)] pub fn call_scatter_add( device: &Device, command_buffer: &CommandBufferRef, @@ -1384,6 +1224,7 @@ pub fn call_scatter_add( Ok(()) } +#[allow(clippy::too_many_arguments)] pub fn call_index_add( device: &Device, command_buffer: &CommandBufferRef, @@ -1910,6 +1751,7 @@ pub enum GgmlDType { F32, } +#[allow(clippy::too_many_arguments)] pub fn call_quantized_matmul_t( device: &Device, command_buffer: &CommandBufferRef, @@ -1925,16 +1767,16 @@ pub fn call_quantized_matmul_t( let ne00 = k as i64; let ne01 = n as i64; let ne02 = b as i64; - let ne03 = 1 as i64; + let ne03 = 1i64; let nb00 = 0i64; - let nb01 = 0 as i64; - let nb02 = 0 as i64; + let nb01 = 0i64; + let nb02 = 0i64; let ne10 = k as i64; let ne11 = m as i64; let ne12 = b as i64; - let ne13 = 1 as i64; + let ne13 = 1i64; let nb10 = 0i64; let nb11 = 0i64; @@ -2169,6 +2011,7 @@ pub struct CallConvTranspose2dCfg<'a> { pub kernel_offset: usize, } +#[allow(clippy::too_many_arguments)] pub fn call_conv_transpose2d( device: &Device, command_buffer: &CommandBufferRef, diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs index b15d9b36..b91c92d8 100644 --- a/candle-metal-kernels/src/tests.rs +++ b/candle-metal-kernels/src/tests.rs @@ -12,7 +12,7 @@ fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> { fn new_buffer<T>(device: &Device, data: &[T]) -> Buffer { let options = MTLResourceOptions::StorageModeManaged; let ptr = data.as_ptr() as *const c_void; - let size = (data.len() * std::mem::size_of::<T>()) as u64; + let size = std::mem::size_of_val(data) as u64; device.new_buffer_with_data(ptr, size, options) } @@ -41,6 +41,10 @@ fn run<T: Clone>(v: &[T], name: unary::contiguous::Kernel) -> Vec<T> { let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); let input = new_buffer(&device, v); + let input = BufferOffset { + buffer: &input, + offset_in_bytes: 0, + }; let output = new_buffer(&device, v); call_unary_contiguous( &device, @@ -48,7 +52,7 @@ fn run<T: Clone>(v: &[T], name: unary::contiguous::Kernel) -> Vec<T> { &kernels, name, v.len(), - &input, + input, &output, ) .unwrap(); @@ -72,8 +76,8 @@ fn run_binary<T: Clone>(x: &[T], y: &[T], name: binary::contiguous::Kernel) -> V &kernels, name, x.len(), - &left, - &right, + BufferOffset::zero_offset(&left), + BufferOffset::zero_offset(&right), &output, ) .unwrap(); @@ -93,7 +97,15 @@ fn run_strided<T: Clone>( let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); let input = new_buffer(&device, v); - let output = new_buffer(&device, v); + let input = BufferOffset { + buffer: &input, + offset_in_bytes: offset, + }; + let output_b = new_buffer(&device, v); + let output = BufferOffset { + buffer: &output_b, + offset_in_bytes: 0, + }; let kernels = Kernels::new(); call_unary_strided( &device, @@ -101,16 +113,14 @@ fn run_strided<T: Clone>( &kernels, kernel, shape, - &input, + input, strides, - offset, - &output, - 0, + output, ) .unwrap(); command_buffer.commit(); command_buffer.wait_until_completed(); - read_to_vec(&output, v.len()) + read_to_vec(&output_b, v.len()) } #[test] @@ -308,8 +318,7 @@ fn run_cast<T: Clone, U: Clone>(v: &[T], name: &'static str) -> Vec<U> { &kernels, name, v.len(), - &input, - 0, + BufferOffset::zero_offset(&input), &output, ) .unwrap(); @@ -521,7 +530,7 @@ fn run_affine<T: Clone>(v: &[T], mul: f64, add: f64) -> Vec<T> { &kernels, "affine_f32", size, - &input, + BufferOffset::zero_offset(&input), &output, mul as f32, add as f32, @@ -554,9 +563,8 @@ fn run_affine_strided<T: Clone>( &kernels, "affine_f32_strided", shape, - &input, + BufferOffset::zero_offset(&input), strides, - 0, &output, mul as f32, add as f32, @@ -633,7 +641,7 @@ fn index_select_strided() { fn index_select_f16() { let embedding: Vec<_> = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0] .into_iter() - .map(|x| f16::from_f32(x)) + .map(f16::from_f32) .collect(); let shape = [5, 2]; let stride = [2, 1]; @@ -700,8 +708,8 @@ fn run_index_select<T: Clone, I: Clone + std::fmt::Debug>( let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); - let embeddings_buffer = new_buffer(&device, &embeddings); - let ids_buffer = new_buffer(&device, &ids); + let embeddings_buffer = new_buffer(&device, embeddings); + let ids_buffer = new_buffer(&device, ids); let left_size: usize = shape[..dim].iter().product(); let right_size: usize = shape[dim + 1..].iter().product(); @@ -711,7 +719,7 @@ fn run_index_select<T: Clone, I: Clone + std::fmt::Debug>( let kernels = Kernels::new(); call_index_select( &device, - &command_buffer, + command_buffer, &kernels, name, shape, @@ -746,8 +754,8 @@ fn run_index_select_strided<T: Clone, I: Clone + std::fmt::Debug>( let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); - let embeddings_buffer = new_buffer(&device, &embeddings); - let ids_buffer = new_buffer(&device, &ids); + let embeddings_buffer = new_buffer(&device, embeddings); + let ids_buffer = new_buffer(&device, ids); let left_size: usize = shape[..dim].iter().product(); let right_size: usize = shape[dim + 1..].iter().product(); @@ -757,7 +765,7 @@ fn run_index_select_strided<T: Clone, I: Clone + std::fmt::Debug>( let kernels = Kernels::new(); call_index_select( &device, - &command_buffer, + command_buffer, &kernels, name, shape, @@ -931,6 +939,7 @@ fn softmax() { ); } +#[allow(clippy::too_many_arguments)] fn run_where_cond<I: Clone, T: Clone>( shape: &[usize], cond: &[I], @@ -1148,7 +1157,7 @@ fn run_random<T: Clone>(name: &'static str, seed: u32, length: usize, a: f32, b: #[test] fn random() { fn calc_mean(data: &[f32]) -> f32 { - let sum = data.iter().sum::<f32>() as f32; + let sum = data.iter().sum::<f32>(); let count = data.len(); assert!(count > 0); sum / count as f32 @@ -1162,7 +1171,7 @@ fn random() { let variance = data .iter() .map(|value| { - let diff = mean - (*value as f32); + let diff = mean - *value; diff * diff }) .sum::<f32>() @@ -1787,6 +1796,7 @@ fn avg_pool2d_u32() { assert_eq!(results, expected); } +#[allow(clippy::too_many_arguments)] fn run_conv_transpose1d<T: Clone>( input: &[T], input_shape: &[usize], diff --git a/candle-metal-kernels/src/utils.rs b/candle-metal-kernels/src/utils.rs new file mode 100644 index 00000000..194cddf4 --- /dev/null +++ b/candle-metal-kernels/src/utils.rs @@ -0,0 +1,162 @@ +use metal::{Buffer, ComputeCommandEncoderRef, ComputePipelineState, MTLSize}; +use std::ffi::c_void; + +/// Most kernels apply similarly across the tensors +/// This creates a strategy that uses the maximum amount of threads per threadgroup (capped at the +/// actual total buffer length). +/// Then kernels can just do their op on their single point in the buffer. +pub(crate) fn linear_split(pipeline: &ComputePipelineState, length: usize) -> (MTLSize, MTLSize) { + let size = length as u64; + let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), size); + let count = (size + width - 1) / width; + let thread_group_count = MTLSize { + width: count, + height: 1, + depth: 1, + }; + + let thread_group_size = MTLSize { + width, + height: 1, + depth: 1, + }; + (thread_group_count, thread_group_size) +} + +// https://github.com/ml-explore/mlx/blob/bddf23f175726a57f0e443cd45518c0757daa166/mlx/backend/metal/utils.h#L96 +pub(crate) fn get_block_dims(dim0: u64, dim1: u64, dim2: u64) -> MTLSize { + let mut pows0 = 0u64; + let mut pows1 = 0u64; + let mut pows2 = 0u64; + let mut sum = 0u64; + loop { + let presum = sum; + // Check all the pows + if dim0 >= (1 << (pows0 + 1)) { + pows0 += 1; + sum += 1; + } + if sum == 10 { + break; + } + if dim1 >= (1 << (pows1 + 1)) { + pows1 += 1; + sum += 1; + } + if sum == 10 { + break; + } + if dim2 >= (1 << (pows2 + 1)) { + pows2 += 1; + sum += 1; + } + if sum == presum || sum == 10 { + break; + } + } + MTLSize { + width: 1 << pows0, + height: 1 << pows1, + depth: 1 << pows2, + } +} + +pub(crate) fn set_param<P: EncoderParam>( + encoder: &ComputeCommandEncoderRef, + position: u64, + data: P, +) { + <P as EncoderParam>::set_param(encoder, position, data) +} + +/// Helper functions to create the various objects on the compute command encoder +/// on a single line. +/// Prevents getting wrong some arguments number and mixing length and size in bytes. +pub(crate) trait EncoderParam { + fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self); +} +macro_rules! primitive { + ($type:ty) => { + impl EncoderParam for $type { + fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) { + encoder.set_bytes( + position, + core::mem::size_of::<$type>() as u64, + &data as *const $type as *const c_void, + ); + } + } + }; +} +primitive!(bool); +primitive!(usize); +primitive!(i32); +primitive!(i64); +primitive!(u32); +primitive!(u64); +primitive!(f32); + +pub struct BufferOffset<'a> { + pub buffer: &'a Buffer, + pub offset_in_bytes: usize, +} + +impl<'a> BufferOffset<'a> { + pub fn zero_offset(buffer: &'a Buffer) -> Self { + Self { + buffer, + offset_in_bytes: 0, + } + } +} + +impl<T> EncoderParam for &[T] { + fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) { + encoder.set_bytes( + position, + core::mem::size_of_val(data) as u64, + data.as_ptr() as *const c_void, + ); + } +} + +impl EncoderParam for &Buffer { + fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) { + encoder.set_buffer(position, Some(data), 0); + } +} + +impl EncoderParam for (&Buffer, usize) { + fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) { + encoder.set_buffer(position, Some(data.0), data.1 as u64); + } +} + +impl<'a> EncoderParam for &BufferOffset<'a> { + fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) { + encoder.set_buffer(position, Some(data.buffer), data.offset_in_bytes as u64); + } +} + +impl EncoderParam for &mut Buffer { + fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) { + encoder.set_buffer(position, Some(data), 0); + } +} + +impl EncoderParam for (&mut Buffer, usize) { + fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) { + encoder.set_buffer(position, Some(data.0), data.1 as u64); + } +} + +#[macro_export] +macro_rules! set_params { + ($encoder:ident, ($($param:expr),+)) => ( + let mut _index = 0; + $( + $crate::utils::set_param($encoder, _index, $param); + _index += 1; + )* + ); +} |