diff options
-rw-r--r-- | candle-metal-kernels/src/lib.rs | 240 | ||||
-rw-r--r-- | candle-metal-kernels/src/utils.rs | 23 |
2 files changed, 143 insertions, 120 deletions
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index 1815dd32..6f723a93 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -1,6 +1,6 @@ use metal::{ - Buffer, CommandBufferRef, CompileOptions, ComputePipelineState, Device, Function, - FunctionConstantValues, Library, MTLDataType, MTLSize, NSUInteger, + Buffer, CompileOptions, ComputePipelineState, Device, Function, FunctionConstantValues, + Library, MTLDataType, MTLSize, NSUInteger, }; use std::collections::HashMap; use std::ffi::c_void; @@ -8,7 +8,7 @@ use std::sync::RwLock; mod utils; pub use utils::BufferOffset; -use utils::{get_block_dims, linear_split}; +use utils::{get_block_dims, linear_split, EncoderProvider}; const AFFINE: &str = include_str!("affine.metal"); const INDEXING: &str = include_str!("indexing.metal"); @@ -297,7 +297,7 @@ impl Kernels { #[allow(clippy::too_many_arguments)] pub fn call_copy2d( device: &Device, - command_buffer: &CommandBufferRef, + ep: impl EncoderProvider, kernels: &Kernels, name: copy2d::Kernel, input: &Buffer, @@ -310,7 +310,7 @@ pub fn call_copy2d( dst_o_in_bytes: usize, ) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, Source::Unary, name.0)?; - let encoder = command_buffer.new_compute_command_encoder(); + let encoder = ep.encoder(); encoder.set_compute_pipeline_state(&pipeline); set_params!( encoder, @@ -333,14 +333,14 @@ pub fn call_copy2d( encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_threads(grid_dims, group_dims); - encoder.end_encoding(); + ep.maybe_end_encoding(encoder); Ok(()) } #[allow(clippy::too_many_arguments)] pub fn call_unary_contiguous_tiled( device: &Device, - command_buffer: &CommandBufferRef, + ep: impl EncoderProvider, kernels: &Kernels, kernel_name: unary::contiguous_tiled::Kernel, length: usize, @@ -348,7 +348,7 @@ pub fn call_unary_contiguous_tiled( output: &Buffer, ) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?; - let encoder = command_buffer.new_compute_command_encoder(); + let encoder = ep.encoder(); let tile_size = 2; let tiles = (length + tile_size - 1) / tile_size; @@ -360,14 +360,14 @@ pub fn call_unary_contiguous_tiled( encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - encoder.end_encoding(); + ep.maybe_end_encoding(encoder); Ok(()) } #[allow(clippy::too_many_arguments)] pub fn call_unary_contiguous( device: &Device, - command_buffer: &CommandBufferRef, + ep: impl EncoderProvider, kernels: &Kernels, kernel_name: unary::contiguous::Kernel, length: usize, @@ -375,7 +375,7 @@ pub fn call_unary_contiguous( output: &Buffer, ) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?; - let encoder = command_buffer.new_compute_command_encoder(); + let encoder = ep.encoder(); encoder.set_compute_pipeline_state(&pipeline); @@ -385,14 +385,14 @@ pub fn call_unary_contiguous( encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - encoder.end_encoding(); + ep.maybe_end_encoding(encoder); Ok(()) } #[allow(clippy::too_many_arguments)] pub fn call_unary_strided( device: &Device, - command_buffer: &CommandBufferRef, + ep: impl EncoderProvider, kernels: &Kernels, name: unary::strided::Kernel, shape: &[usize], @@ -404,7 +404,7 @@ pub fn call_unary_strided( let length: usize = shape.iter().product(); let num_dims: usize = shape.len(); - let encoder = command_buffer.new_compute_command_encoder(); + let encoder = ep.encoder(); let (thread_group_count, thread_group_size) = linear_split(&pipeline, length); encoder.set_compute_pipeline_state(&pipeline); @@ -412,14 +412,14 @@ pub fn call_unary_strided( encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - encoder.end_encoding(); + ep.maybe_end_encoding(encoder); Ok(()) } #[allow(clippy::too_many_arguments)] pub fn call_binary_contiguous( device: &Device, - command_buffer: &CommandBufferRef, + ep: impl EncoderProvider, kernels: &Kernels, kernel_name: binary::contiguous::Kernel, length: usize, @@ -429,7 +429,7 @@ pub fn call_binary_contiguous( ) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, Source::Binary, kernel_name.0)?; - let encoder = command_buffer.new_compute_command_encoder(); + let encoder = ep.encoder(); encoder.set_compute_pipeline_state(&pipeline); set_params!(encoder, (length, &left, &right, output)); @@ -440,14 +440,14 @@ pub fn call_binary_contiguous( encoder.use_resource(right.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - encoder.end_encoding(); + ep.maybe_end_encoding(encoder); Ok(()) } #[allow(clippy::too_many_arguments)] pub fn call_binary_strided( device: &Device, - command_buffer: &CommandBufferRef, + ep: impl EncoderProvider, kernels: &Kernels, name: binary::strided::Kernel, shape: &[usize], @@ -460,7 +460,7 @@ pub fn call_binary_strided( let pipeline = kernels.load_pipeline(device, Source::Binary, name.0)?; let num_dims: usize = shape.len(); - let encoder = command_buffer.new_compute_command_encoder(); + let encoder = ep.encoder(); let width: usize = shape.iter().product(); let length: usize = shape.iter().product(); let (thread_group_count, thread_group_size) = linear_split(&pipeline, width); @@ -483,7 +483,7 @@ pub fn call_binary_strided( encoder.use_resource(right_input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - encoder.end_encoding(); + ep.maybe_end_encoding(encoder); Ok(()) } @@ -491,7 +491,7 @@ pub fn call_binary_strided( #[allow(clippy::too_many_arguments)] pub fn call_cast_contiguous( device: &Device, - command_buffer: &CommandBufferRef, + ep: impl EncoderProvider, kernels: &Kernels, kernel_name: &'static str, length: usize, @@ -500,7 +500,7 @@ pub fn call_cast_contiguous( ) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?; - let encoder = command_buffer.new_compute_command_encoder(); + let encoder = ep.encoder(); encoder.set_compute_pipeline_state(&pipeline); set_params!(encoder, (length, &input, output)); @@ -509,14 +509,14 @@ pub fn call_cast_contiguous( encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - encoder.end_encoding(); + ep.maybe_end_encoding(encoder); Ok(()) } #[allow(clippy::too_many_arguments)] pub fn call_cast_strided( device: &Device, - command_buffer: &CommandBufferRef, + ep: impl EncoderProvider, kernels: &Kernels, kernel_name: &'static str, shape: &[usize], @@ -526,7 +526,7 @@ pub fn call_cast_strided( ) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?; - let encoder = command_buffer.new_compute_command_encoder(); + let encoder = ep.encoder(); encoder.set_compute_pipeline_state(&pipeline); let length: usize = shape.iter().product(); @@ -541,14 +541,14 @@ pub fn call_cast_strided( encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - encoder.end_encoding(); + ep.maybe_end_encoding(encoder); Ok(()) } #[allow(clippy::too_many_arguments)] pub fn call_reduce_contiguous( device: &Device, - command_buffer: &CommandBufferRef, + ep: impl EncoderProvider, kernels: &Kernels, kernel_name: &'static str, length: usize, @@ -559,7 +559,7 @@ pub fn call_reduce_contiguous( let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?; let elements_to_sum = length / out_length; - let encoder = command_buffer.new_compute_command_encoder(); + let encoder = ep.encoder(); encoder.set_compute_pipeline_state(&pipeline); set_params!(encoder, (length, elements_to_sum, &input, output)); @@ -585,14 +585,14 @@ pub fn call_reduce_contiguous( encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - encoder.end_encoding(); + ep.maybe_end_encoding(encoder); Ok(()) } #[allow(clippy::too_many_arguments)] pub fn call_reduce_strided( device: &Device, - command_buffer: &CommandBufferRef, + ep: impl EncoderProvider, kernels: &Kernels, kernel_name: &'static str, shape: &[usize], @@ -605,7 +605,7 @@ pub fn call_reduce_strided( let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?; let elements_to_sum = length / out_length; - let encoder = command_buffer.new_compute_command_encoder(); + let encoder = ep.encoder(); encoder.set_compute_pipeline_state(&pipeline); set_params!( @@ -634,14 +634,14 @@ pub fn call_reduce_strided( encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - encoder.end_encoding(); + ep.maybe_end_encoding(encoder); Ok(()) } #[allow(clippy::too_many_arguments)] pub fn call_last_softmax( device: &Device, - command_buffer: &CommandBufferRef, + ep: impl EncoderProvider, kernels: &Kernels, kernel_name: &'static str, length: usize, @@ -651,7 +651,7 @@ pub fn call_last_softmax( output: &Buffer, ) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?; - let encoder = command_buffer.new_compute_command_encoder(); + let encoder = ep.encoder(); encoder.set_compute_pipeline_state(&pipeline); set_params!( @@ -682,14 +682,14 @@ pub fn call_last_softmax( encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - encoder.end_encoding(); + ep.maybe_end_encoding(encoder); Ok(()) } #[allow(clippy::too_many_arguments)] pub fn call_rms_norm( device: &Device, - command_buffer: &CommandBufferRef, + ep: impl EncoderProvider, kernels: &Kernels, kernel_name: &'static str, length: usize, @@ -702,7 +702,7 @@ pub fn call_rms_norm( output: &Buffer, ) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?; - let encoder = command_buffer.new_compute_command_encoder(); + let encoder = ep.encoder(); encoder.set_compute_pipeline_state(&pipeline); set_params!( @@ -741,14 +741,14 @@ pub fn call_rms_norm( encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.set_threadgroup_memory_length(0, (width * 4).max(16) as u64); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - encoder.end_encoding(); + ep.maybe_end_encoding(encoder); Ok(()) } #[allow(clippy::too_many_arguments)] pub fn call_layer_norm( device: &Device, - command_buffer: &CommandBufferRef, + ep: impl EncoderProvider, kernels: &Kernels, kernel_name: &'static str, length: usize, @@ -763,7 +763,7 @@ pub fn call_layer_norm( output: &Buffer, ) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?; - let encoder = command_buffer.new_compute_command_encoder(); + let encoder = ep.encoder(); encoder.set_compute_pipeline_state(&pipeline); set_params!( @@ -803,14 +803,14 @@ pub fn call_layer_norm( encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.set_threadgroup_memory_length(0, (width * 8).max(32) as u64); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - encoder.end_encoding(); + ep.maybe_end_encoding(encoder); Ok(()) } #[allow(clippy::too_many_arguments)] pub fn call_rope_i( device: &Device, - command_buffer: &CommandBufferRef, + ep: impl EncoderProvider, kernels: &Kernels, kernel_name: &'static str, bh: usize, @@ -824,7 +824,7 @@ pub fn call_rope_i( output: &Buffer, ) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?; - let encoder = command_buffer.new_compute_command_encoder(); + let encoder = ep.encoder(); encoder.set_compute_pipeline_state(&pipeline); set_params!( @@ -844,14 +844,14 @@ pub fn call_rope_i( encoder.use_resource(sin, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - encoder.end_encoding(); + ep.maybe_end_encoding(encoder); Ok(()) } #[allow(clippy::too_many_arguments)] pub fn call_rope_thd( device: &Device, - command_buffer: &CommandBufferRef, + ep: impl EncoderProvider, kernels: &Kernels, kernel_name: &'static str, b: usize, @@ -867,7 +867,7 @@ pub fn call_rope_thd( output: &Buffer, ) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?; - let encoder = command_buffer.new_compute_command_encoder(); + let encoder = ep.encoder(); encoder.set_compute_pipeline_state(&pipeline); set_params!( @@ -889,14 +889,14 @@ pub fn call_rope_thd( encoder.use_resource(sin, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - encoder.end_encoding(); + ep.maybe_end_encoding(encoder); Ok(()) } #[allow(clippy::too_many_arguments)] pub fn call_rope( device: &Device, - command_buffer: &CommandBufferRef, + ep: impl EncoderProvider, kernels: &Kernels, kernel_name: &'static str, bh: usize, @@ -911,7 +911,7 @@ pub fn call_rope( output: &Buffer, ) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?; - let encoder = command_buffer.new_compute_command_encoder(); + let encoder = ep.encoder(); encoder.set_compute_pipeline_state(&pipeline); set_params!( @@ -932,14 +932,14 @@ pub fn call_rope( encoder.use_resource(sin, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - encoder.end_encoding(); + ep.maybe_end_encoding(encoder); Ok(()) } #[allow(clippy::too_many_arguments)] pub fn call_affine( device: &Device, - command_buffer: &CommandBufferRef, + ep: impl EncoderProvider, kernels: &Kernels, name: &'static str, size: usize, @@ -950,7 +950,7 @@ pub fn call_affine( ) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, Source::Affine, name)?; - let encoder = command_buffer.new_compute_command_encoder(); + let encoder = ep.encoder(); encoder.set_compute_pipeline_state(&pipeline); set_params!(encoder, (size, mul, add, &input, output)); @@ -959,14 +959,14 @@ pub fn call_affine( encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - encoder.end_encoding(); + ep.maybe_end_encoding(encoder); Ok(()) } #[allow(clippy::too_many_arguments)] pub fn call_affine_strided( device: &Device, - command_buffer: &CommandBufferRef, + ep: impl EncoderProvider, kernels: &Kernels, name: &'static str, shape: &[usize], @@ -979,7 +979,7 @@ pub fn call_affine_strided( let pipeline = kernels.load_pipeline(device, Source::Affine, name)?; let size: usize = shape.iter().product(); - let encoder = command_buffer.new_compute_command_encoder(); + let encoder = ep.encoder(); encoder.set_compute_pipeline_state(&pipeline); set_params!( @@ -1000,14 +1000,14 @@ pub fn call_affine_strided( encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - encoder.end_encoding(); + ep.maybe_end_encoding(encoder); Ok(()) } #[allow(clippy::too_many_arguments)] pub fn call_powf( device: &Device, - command_buffer: &CommandBufferRef, + ep: impl EncoderProvider, kernels: &Kernels, name: &'static str, size: usize, @@ -1017,7 +1017,7 @@ pub fn call_powf( ) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, Source::Affine, name)?; - let encoder = command_buffer.new_compute_command_encoder(); + let encoder = ep.encoder(); encoder.set_compute_pipeline_state(&pipeline); set_params!(encoder, (size, mul, &input, output)); @@ -1026,14 +1026,14 @@ pub fn call_powf( encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - encoder.end_encoding(); + ep.maybe_end_encoding(encoder); Ok(()) } #[allow(clippy::too_many_arguments)] pub fn call_powf_strided( device: &Device, - command_buffer: &CommandBufferRef, + ep: impl EncoderProvider, kernels: &Kernels, name: &'static str, shape: &[usize], @@ -1045,7 +1045,7 @@ pub fn call_powf_strided( let pipeline = kernels.load_pipeline(device, Source::Affine, name)?; let size: usize = shape.iter().product(); - let encoder = command_buffer.new_compute_command_encoder(); + let encoder = ep.encoder(); encoder.set_compute_pipeline_state(&pipeline); set_params!( @@ -1057,14 +1057,14 @@ pub fn call_powf_strided( encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - encoder.end_encoding(); + ep.maybe_end_encoding(encoder); Ok(()) } #[allow(clippy::too_many_arguments)] pub fn call_elu( device: &Device, - command_buffer: &CommandBufferRef, + ep: impl EncoderProvider, kernels: &Kernels, name: &'static str, size: usize, @@ -1074,7 +1074,7 @@ pub fn call_elu( ) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, Source::Affine, name)?; - let encoder = command_buffer.new_compute_command_encoder(); + let encoder = ep.encoder(); encoder.set_compute_pipeline_state(&pipeline); set_params!(encoder, (size, mul, &input, output)); @@ -1083,14 +1083,14 @@ pub fn call_elu( encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - encoder.end_encoding(); + ep.maybe_end_encoding(encoder); Ok(()) } #[allow(clippy::too_many_arguments)] pub fn call_elu_strided( device: &Device, - command_buffer: &CommandBufferRef, + ep: impl EncoderProvider, kernels: &Kernels, name: &'static str, shape: &[usize], @@ -1102,7 +1102,7 @@ pub fn call_elu_strided( let pipeline = kernels.load_pipeline(device, Source::Affine, name)?; let size: usize = shape.iter().product(); - let encoder = command_buffer.new_compute_command_encoder(); + let encoder = ep.encoder(); encoder.set_compute_pipeline_state(&pipeline); set_params!( @@ -1114,14 +1114,14 @@ pub fn call_elu_strided( encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - encoder.end_encoding(); + ep.maybe_end_encoding(encoder); Ok(()) } #[allow(clippy::too_many_arguments)] pub fn call_where_cond_strided( device: &Device, - command_buffer: &CommandBufferRef, + ep: impl EncoderProvider, kernels: &Kernels, name: &'static str, shape: &[usize], @@ -1135,7 +1135,7 @@ pub fn call_where_cond_strided( ) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, Source::Ternary, name)?; - let encoder = command_buffer.new_compute_command_encoder(); + let encoder = ep.encoder(); encoder.set_compute_pipeline_state(&pipeline); let size: usize = shape.iter().product(); @@ -1164,14 +1164,14 @@ pub fn call_where_cond_strided( encoder.use_resource(right.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - encoder.end_encoding(); + ep.maybe_end_encoding(encoder); Ok(()) } #[allow(clippy::too_many_arguments)] pub fn call_index_select( device: &Device, - command_buffer: &CommandBufferRef, + ep: impl EncoderProvider, kernels: &Kernels, name: &'static str, shape: &[usize], @@ -1191,7 +1191,7 @@ pub fn call_index_select( let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?; - let encoder = command_buffer.new_compute_command_encoder(); + let encoder = ep.encoder(); encoder.set_compute_pipeline_state(&pipeline); @@ -1218,14 +1218,14 @@ pub fn call_index_select( encoder.use_resource(ids.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - encoder.end_encoding(); + ep.maybe_end_encoding(encoder); Ok(()) } #[allow(clippy::too_many_arguments)] pub fn call_gather( device: &Device, - command_buffer: &CommandBufferRef, + ep: impl EncoderProvider, kernels: &Kernels, name: &'static str, shape: &[usize], @@ -1242,7 +1242,7 @@ pub fn call_gather( let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?; - let encoder = command_buffer.new_compute_command_encoder(); + let encoder = ep.encoder(); encoder.set_compute_pipeline_state(&pipeline); @@ -1266,14 +1266,14 @@ pub fn call_gather( encoder.use_resource(ids.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - encoder.end_encoding(); + ep.maybe_end_encoding(encoder); Ok(()) } #[allow(clippy::too_many_arguments)] pub fn call_scatter_add( device: &Device, - command_buffer: &CommandBufferRef, + ep: impl EncoderProvider, kernels: &Kernels, name: &'static str, src_shape: &[usize], @@ -1291,7 +1291,7 @@ pub fn call_scatter_add( let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?; - let encoder = command_buffer.new_compute_command_encoder(); + let encoder = ep.encoder(); encoder.set_compute_pipeline_state(&pipeline); @@ -1315,14 +1315,14 @@ pub fn call_scatter_add( encoder.use_resource(ids.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - encoder.end_encoding(); + ep.maybe_end_encoding(encoder); Ok(()) } #[allow(clippy::too_many_arguments)] pub fn call_index_add( device: &Device, - command_buffer: &CommandBufferRef, + ep: impl EncoderProvider, kernels: &Kernels, name: &'static str, src_shape: &[usize], @@ -1341,7 +1341,7 @@ pub fn call_index_add( let ids_dim_size = ids_shape[0]; let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?; - let encoder = command_buffer.new_compute_command_encoder(); + let encoder = ep.encoder(); encoder.set_compute_pipeline_state(&pipeline); @@ -1366,7 +1366,7 @@ pub fn call_index_add( encoder.use_resource(ids.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - encoder.end_encoding(); + ep.maybe_end_encoding(encoder); Ok(()) } @@ -1453,7 +1453,7 @@ impl ConstantValues { #[allow(clippy::too_many_arguments)] pub fn call_gemm( device: &Device, - command_buffer: &CommandBufferRef, + ep: impl EncoderProvider, kernels: &Kernels, name: &'static str, (b, m, n, k): (usize, usize, usize, usize), @@ -1572,7 +1572,7 @@ pub fn call_gemm( }; let block_bytes = block_elements * bytes; - let encoder = command_buffer.new_compute_command_encoder(); + let encoder = ep.encoder(); encoder.set_compute_pipeline_state(&pipeline); encoder.set_threadgroup_memory_length(0, block_bytes.into()); encoder.set_buffer(0, Some(lhs_buffer), lhs_offset as NSUInteger); @@ -1615,7 +1615,7 @@ pub fn call_gemm( encoder.use_resource(rhs_buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(grid_size, group_size); - encoder.end_encoding(); + ep.maybe_end_encoding(encoder); Ok(()) } @@ -1623,7 +1623,7 @@ pub fn call_gemm( #[allow(clippy::too_many_arguments)] pub fn call_im2col1d_strided( device: &Device, - command_buffer: &CommandBufferRef, + ep: impl EncoderProvider, kernels: &Kernels, name: &'static str, shape: &[usize], @@ -1636,7 +1636,7 @@ pub fn call_im2col1d_strided( let l_out = (shape[2] + 2 * padding - dilation * (k_size - 1) - 1) / stride + 1; let dst_el = shape[0] * l_out * shape[1] * k_size; - let encoder = command_buffer.new_compute_command_encoder(); + let encoder = ep.encoder(); let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el); encoder.set_compute_pipeline_state(&pipeline); set_params!( @@ -1646,7 +1646,7 @@ pub fn call_im2col1d_strided( encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - encoder.end_encoding(); + ep.maybe_end_encoding(encoder); Ok(()) } @@ -1654,7 +1654,7 @@ pub fn call_im2col1d_strided( #[allow(clippy::too_many_arguments)] pub fn call_col2im1d( device: &Device, - command_buffer: &CommandBufferRef, + ep: impl EncoderProvider, kernels: &Kernels, name: &'static str, shape: &[usize], @@ -1669,7 +1669,7 @@ pub fn call_col2im1d( let l_out = (l_in - 1) * stride + k_size; let dst_el = shape[0] * c_out * l_out; - let encoder = command_buffer.new_compute_command_encoder(); + let encoder = ep.encoder(); let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el); encoder.set_compute_pipeline_state(&pipeline); set_params!( @@ -1679,7 +1679,7 @@ pub fn call_col2im1d( encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - encoder.end_encoding(); + ep.maybe_end_encoding(encoder); Ok(()) } @@ -1687,7 +1687,7 @@ pub fn call_col2im1d( #[allow(clippy::too_many_arguments)] pub fn call_im2col_strided( device: &Device, - command_buffer: &CommandBufferRef, + ep: impl EncoderProvider, kernels: &Kernels, name: &'static str, shape: &[usize], @@ -1705,7 +1705,7 @@ pub fn call_im2col_strided( let dst_el = shape[0] * h_out * w_out * shape[1] * h_k * w_k; - let encoder = command_buffer.new_compute_command_encoder(); + let encoder = ep.encoder(); let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el); encoder.set_compute_pipeline_state(&pipeline); set_params!( @@ -1718,7 +1718,7 @@ pub fn call_im2col_strided( encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - encoder.end_encoding(); + ep.maybe_end_encoding(encoder); Ok(()) } @@ -1726,7 +1726,7 @@ pub fn call_im2col_strided( #[allow(clippy::too_many_arguments)] pub fn call_upsample_nearest_2d( device: &Device, - command_buffer: &CommandBufferRef, + ep: impl EncoderProvider, kernels: &Kernels, name: &'static str, shape: &[usize], @@ -1741,7 +1741,7 @@ pub fn call_upsample_nearest_2d( let scale_w = shape[2] as f32 / out_w as f32; let scale_h = shape[3] as f32 / out_h as f32; let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el); - let encoder = command_buffer.new_compute_command_encoder(); + let encoder = ep.encoder(); encoder.set_compute_pipeline_state(&pipeline); set_params!( encoder, @@ -1750,7 +1750,7 @@ pub fn call_upsample_nearest_2d( encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - encoder.end_encoding(); + ep.maybe_end_encoding(encoder); Ok(()) } @@ -1758,7 +1758,7 @@ pub fn call_upsample_nearest_2d( #[allow(clippy::too_many_arguments)] pub fn call_random_uniform( device: &Device, - command_buffer: &CommandBufferRef, + ep: impl EncoderProvider, kernels: &Kernels, name: &'static str, min: f32, @@ -1773,7 +1773,7 @@ pub fn call_random_uniform( )); } let pipeline = kernels.load_pipeline(device, Source::Random, name)?; - let encoder = command_buffer.new_compute_command_encoder(); + let encoder = ep.encoder(); let odd = (length % 2 != 0) as usize; let (thread_group_count, thread_group_size) = linear_split(&pipeline, length / 2 + odd); @@ -1788,7 +1788,7 @@ pub fn call_random_uniform( ); encoder.use_resource(buffer, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - encoder.end_encoding(); + ep.maybe_end_encoding(encoder); Ok(()) } @@ -1796,7 +1796,7 @@ pub fn call_random_uniform( #[allow(clippy::too_many_arguments)] pub fn call_random_normal( device: &Device, - command_buffer: &CommandBufferRef, + ep: impl EncoderProvider, kernels: &Kernels, name: &'static str, mean: f32, @@ -1806,7 +1806,7 @@ pub fn call_random_normal( buffer: &Buffer, ) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, Source::Random, name)?; - let encoder = command_buffer.new_compute_command_encoder(); + let encoder = ep.encoder(); let odd = (length % 2 != 0) as usize; let (thread_group_count, thread_group_size) = linear_split(&pipeline, length / 2 + odd); @@ -1821,7 +1821,7 @@ pub fn call_random_normal( ); encoder.use_resource(buffer, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - encoder.end_encoding(); + ep.maybe_end_encoding(encoder); Ok(()) } @@ -1847,7 +1847,7 @@ pub enum GgmlDType { #[allow(clippy::too_many_arguments)] pub fn call_quantized_matmul_mv_t( device: &Device, - command_buffer: &CommandBufferRef, + ep: impl EncoderProvider, kernels: &Kernels, dtype: GgmlDType, (b, m, n, k): (usize, usize, usize, usize), @@ -1961,7 +1961,7 @@ pub fn call_quantized_matmul_mv_t( }; let pipeline = kernels.load_pipeline(device, Source::Quantized, name)?; - let encoder = command_buffer.new_compute_command_encoder(); + let encoder = ep.encoder(); encoder.set_compute_pipeline_state(&pipeline); set_params!( @@ -1993,7 +1993,7 @@ pub fn call_quantized_matmul_mv_t( encoder.use_resource(dst, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_groups_count, threads_per_threadgroup); - encoder.end_encoding(); + ep.maybe_end_encoding(encoder); Ok(()) } @@ -2005,7 +2005,7 @@ fn divide(m: usize, b: usize) -> NSUInteger { #[allow(clippy::too_many_arguments)] pub fn call_pool2d( device: &Device, - command_buffer: &CommandBufferRef, + ep: impl EncoderProvider, kernels: &Kernels, name: &'static str, shape: &[usize], @@ -2022,7 +2022,7 @@ pub fn call_pool2d( let dst_el = out_w * out_h * shape[0] * shape[1]; let pipeline: ComputePipelineState = kernels.load_pipeline(device, Source::Conv, name)?; let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el); - let encoder = command_buffer.new_compute_command_encoder(); + let encoder = ep.encoder(); encoder.set_compute_pipeline_state(&pipeline); set_params!( encoder, @@ -2031,14 +2031,14 @@ pub fn call_pool2d( encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - encoder.end_encoding(); + ep.maybe_end_encoding(encoder); Ok(()) } #[allow(clippy::too_many_arguments)] pub fn call_conv_transpose1d( device: &Device, - command_buffer: &CommandBufferRef, + ep: impl EncoderProvider, kernels: &Kernels, name: &'static str, dilation: usize, @@ -2061,7 +2061,7 @@ pub fn call_conv_transpose1d( let dst_el = c_out * l_out * b_size; let pipeline: ComputePipelineState = kernels.load_pipeline(device, Source::Conv, name)?; let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el); - let encoder = command_buffer.new_compute_command_encoder(); + let encoder = ep.encoder(); encoder.set_compute_pipeline_state(&pipeline); set_params!( encoder, @@ -2084,7 +2084,7 @@ pub fn call_conv_transpose1d( encoder.use_resource(kernel, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - encoder.end_encoding(); + ep.maybe_end_encoding(encoder); Ok(()) } @@ -2108,7 +2108,7 @@ pub struct CallConvTranspose2dCfg<'a> { #[allow(clippy::too_many_arguments)] pub fn call_conv_transpose2d( device: &Device, - command_buffer: &CommandBufferRef, + ep: impl EncoderProvider, kernels: &Kernels, name: &'static str, cfg: CallConvTranspose2dCfg, @@ -2119,7 +2119,7 @@ pub fn call_conv_transpose2d( let dst_el = cfg.c_out * cfg.out_w * cfg.out_h * cfg.b_size; let pipeline: ComputePipelineState = kernels.load_pipeline(device, Source::Conv, name)?; let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el); - let encoder = command_buffer.new_compute_command_encoder(); + let encoder = ep.encoder(); encoder.set_compute_pipeline_state(&pipeline); set_params!( encoder, @@ -2143,14 +2143,14 @@ pub fn call_conv_transpose2d( encoder.use_resource(kernel, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - encoder.end_encoding(); + ep.maybe_end_encoding(encoder); Ok(()) } #[allow(clippy::too_many_arguments)] pub fn call_arg_sort( device: &Device, - command_buffer: &CommandBufferRef, + ep: impl EncoderProvider, kernels: &Kernels, name: &'static str, nrows: usize, @@ -2160,7 +2160,7 @@ pub fn call_arg_sort( dst: &Buffer, ) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, Source::Sort, name)?; - let encoder = command_buffer.new_compute_command_encoder(); + let encoder = ep.encoder(); encoder.set_compute_pipeline_state(&pipeline); set_params!(encoder, (&src, dst, ncols as i64, ncols_pad as i64)); @@ -2180,7 +2180,7 @@ pub fn call_arg_sort( encoder.use_resource(dst, metal::MTLResourceUsage::Write); encoder.set_threadgroup_memory_length(0, (ncols_pad * 4).max(16) as u64); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - encoder.end_encoding(); + ep.maybe_end_encoding(encoder); Ok(()) } diff --git a/candle-metal-kernels/src/utils.rs b/candle-metal-kernels/src/utils.rs index 194cddf4..4ef2162c 100644 --- a/candle-metal-kernels/src/utils.rs +++ b/candle-metal-kernels/src/utils.rs @@ -160,3 +160,26 @@ macro_rules! set_params { )* ); } + +pub trait EncoderProvider { + fn encoder(&self) -> &ComputeCommandEncoderRef; + fn maybe_end_encoding(&self, enc: &ComputeCommandEncoderRef); +} + +impl EncoderProvider for &metal::CommandBuffer { + fn encoder(&self) -> &ComputeCommandEncoderRef { + self.new_compute_command_encoder() + } + fn maybe_end_encoding(&self, enc: &ComputeCommandEncoderRef) { + enc.end_encoding() + } +} + +impl EncoderProvider for &metal::CommandBufferRef { + fn encoder(&self) -> &ComputeCommandEncoderRef { + self.new_compute_command_encoder() + } + fn maybe_end_encoding(&self, enc: &ComputeCommandEncoderRef) { + enc.end_encoding() + } +} |