diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-07-24 15:29:56 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-07-24 16:29:56 +0200 |
commit | ddafc61055601002622778b7762c15bd60057c1f (patch) | |
tree | 5363cf002fb93d9c0368140c1775721aa06d98bd /candle-metal-kernels | |
parent | a925ae6bc659d1b40570b5068b6913d38e75b12e (diff) | |
download | candle-ddafc61055601002622778b7762c15bd60057c1f.tar.gz candle-ddafc61055601002622778b7762c15bd60057c1f.tar.bz2 candle-ddafc61055601002622778b7762c15bd60057c1f.zip |
Use RAII for terminating the encoding. (#2353)
Diffstat (limited to 'candle-metal-kernels')
-rw-r--r-- | candle-metal-kernels/src/lib.rs | 90 | ||||
-rw-r--r-- | candle-metal-kernels/src/utils.rs | 40 |
2 files changed, 69 insertions, 61 deletions
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index 6f723a93..e0c97962 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -1,6 +1,6 @@ use metal::{ - Buffer, CompileOptions, ComputePipelineState, Device, Function, FunctionConstantValues, - Library, MTLDataType, MTLSize, NSUInteger, + Buffer, CompileOptions, ComputeCommandEncoderRef, ComputePipelineState, Device, Function, + FunctionConstantValues, Library, MTLDataType, MTLSize, NSUInteger, }; use std::collections::HashMap; use std::ffi::c_void; @@ -311,6 +311,7 @@ pub fn call_copy2d( ) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, Source::Unary, name.0)?; let encoder = ep.encoder(); + let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); encoder.set_compute_pipeline_state(&pipeline); set_params!( encoder, @@ -333,7 +334,6 @@ pub fn call_copy2d( encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_threads(grid_dims, group_dims); - ep.maybe_end_encoding(encoder); Ok(()) } @@ -349,6 +349,7 @@ pub fn call_unary_contiguous_tiled( ) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?; let encoder = ep.encoder(); + let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); let tile_size = 2; let tiles = (length + tile_size - 1) / tile_size; @@ -360,7 +361,6 @@ pub fn call_unary_contiguous_tiled( encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - ep.maybe_end_encoding(encoder); Ok(()) } @@ -376,6 +376,7 @@ pub fn call_unary_contiguous( ) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?; let encoder = ep.encoder(); + let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); encoder.set_compute_pipeline_state(&pipeline); @@ -385,7 +386,6 @@ pub fn call_unary_contiguous( encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - ep.maybe_end_encoding(encoder); Ok(()) } @@ -405,6 +405,7 @@ pub fn call_unary_strided( let length: usize = shape.iter().product(); let num_dims: usize = shape.len(); let encoder = ep.encoder(); + let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); let (thread_group_count, thread_group_size) = linear_split(&pipeline, length); encoder.set_compute_pipeline_state(&pipeline); @@ -412,7 +413,6 @@ pub fn call_unary_strided( encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - ep.maybe_end_encoding(encoder); Ok(()) } @@ -430,6 +430,7 @@ pub fn call_binary_contiguous( let pipeline = kernels.load_pipeline(device, Source::Binary, kernel_name.0)?; let encoder = ep.encoder(); + let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); encoder.set_compute_pipeline_state(&pipeline); set_params!(encoder, (length, &left, &right, output)); @@ -440,7 +441,6 @@ pub fn call_binary_contiguous( encoder.use_resource(right.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - ep.maybe_end_encoding(encoder); Ok(()) } @@ -461,6 +461,7 @@ pub fn call_binary_strided( let num_dims: usize = shape.len(); let encoder = ep.encoder(); + let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); let width: usize = shape.iter().product(); let length: usize = shape.iter().product(); let (thread_group_count, thread_group_size) = linear_split(&pipeline, width); @@ -483,7 +484,6 @@ pub fn call_binary_strided( encoder.use_resource(right_input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - ep.maybe_end_encoding(encoder); Ok(()) } @@ -501,6 +501,7 @@ pub fn call_cast_contiguous( let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?; let encoder = ep.encoder(); + let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); encoder.set_compute_pipeline_state(&pipeline); set_params!(encoder, (length, &input, output)); @@ -509,7 +510,6 @@ pub fn call_cast_contiguous( encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - ep.maybe_end_encoding(encoder); Ok(()) } @@ -527,6 +527,7 @@ pub fn call_cast_strided( let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?; let encoder = ep.encoder(); + let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); encoder.set_compute_pipeline_state(&pipeline); let length: usize = shape.iter().product(); @@ -541,7 +542,6 @@ pub fn call_cast_strided( encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - ep.maybe_end_encoding(encoder); Ok(()) } @@ -560,6 +560,7 @@ pub fn call_reduce_contiguous( let elements_to_sum = length / out_length; let encoder = ep.encoder(); + let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); encoder.set_compute_pipeline_state(&pipeline); set_params!(encoder, (length, elements_to_sum, &input, output)); @@ -585,7 +586,6 @@ pub fn call_reduce_contiguous( encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - ep.maybe_end_encoding(encoder); Ok(()) } @@ -606,6 +606,7 @@ pub fn call_reduce_strided( let elements_to_sum = length / out_length; let encoder = ep.encoder(); + let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); encoder.set_compute_pipeline_state(&pipeline); set_params!( @@ -634,7 +635,6 @@ pub fn call_reduce_strided( encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - ep.maybe_end_encoding(encoder); Ok(()) } @@ -652,6 +652,7 @@ pub fn call_last_softmax( ) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?; let encoder = ep.encoder(); + let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); encoder.set_compute_pipeline_state(&pipeline); set_params!( @@ -682,7 +683,6 @@ pub fn call_last_softmax( encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - ep.maybe_end_encoding(encoder); Ok(()) } @@ -703,6 +703,7 @@ pub fn call_rms_norm( ) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?; let encoder = ep.encoder(); + let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); encoder.set_compute_pipeline_state(&pipeline); set_params!( @@ -741,7 +742,6 @@ pub fn call_rms_norm( encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.set_threadgroup_memory_length(0, (width * 4).max(16) as u64); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - ep.maybe_end_encoding(encoder); Ok(()) } @@ -764,6 +764,7 @@ pub fn call_layer_norm( ) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?; let encoder = ep.encoder(); + let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); encoder.set_compute_pipeline_state(&pipeline); set_params!( @@ -803,7 +804,6 @@ pub fn call_layer_norm( encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.set_threadgroup_memory_length(0, (width * 8).max(32) as u64); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - ep.maybe_end_encoding(encoder); Ok(()) } @@ -825,6 +825,7 @@ pub fn call_rope_i( ) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?; let encoder = ep.encoder(); + let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); encoder.set_compute_pipeline_state(&pipeline); set_params!( @@ -844,7 +845,6 @@ pub fn call_rope_i( encoder.use_resource(sin, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - ep.maybe_end_encoding(encoder); Ok(()) } @@ -868,6 +868,7 @@ pub fn call_rope_thd( ) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?; let encoder = ep.encoder(); + let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); encoder.set_compute_pipeline_state(&pipeline); set_params!( @@ -889,7 +890,6 @@ pub fn call_rope_thd( encoder.use_resource(sin, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - ep.maybe_end_encoding(encoder); Ok(()) } @@ -912,6 +912,7 @@ pub fn call_rope( ) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?; let encoder = ep.encoder(); + let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); encoder.set_compute_pipeline_state(&pipeline); set_params!( @@ -932,7 +933,6 @@ pub fn call_rope( encoder.use_resource(sin, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - ep.maybe_end_encoding(encoder); Ok(()) } @@ -951,6 +951,7 @@ pub fn call_affine( let pipeline = kernels.load_pipeline(device, Source::Affine, name)?; let encoder = ep.encoder(); + let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); encoder.set_compute_pipeline_state(&pipeline); set_params!(encoder, (size, mul, add, &input, output)); @@ -959,7 +960,6 @@ pub fn call_affine( encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - ep.maybe_end_encoding(encoder); Ok(()) } @@ -980,6 +980,7 @@ pub fn call_affine_strided( let size: usize = shape.iter().product(); let encoder = ep.encoder(); + let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); encoder.set_compute_pipeline_state(&pipeline); set_params!( @@ -1000,7 +1001,6 @@ pub fn call_affine_strided( encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - ep.maybe_end_encoding(encoder); Ok(()) } @@ -1018,6 +1018,7 @@ pub fn call_powf( let pipeline = kernels.load_pipeline(device, Source::Affine, name)?; let encoder = ep.encoder(); + let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); encoder.set_compute_pipeline_state(&pipeline); set_params!(encoder, (size, mul, &input, output)); @@ -1026,7 +1027,6 @@ pub fn call_powf( encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - ep.maybe_end_encoding(encoder); Ok(()) } @@ -1046,6 +1046,7 @@ pub fn call_powf_strided( let size: usize = shape.iter().product(); let encoder = ep.encoder(); + let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); encoder.set_compute_pipeline_state(&pipeline); set_params!( @@ -1057,7 +1058,6 @@ pub fn call_powf_strided( encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - ep.maybe_end_encoding(encoder); Ok(()) } @@ -1075,6 +1075,7 @@ pub fn call_elu( let pipeline = kernels.load_pipeline(device, Source::Affine, name)?; let encoder = ep.encoder(); + let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); encoder.set_compute_pipeline_state(&pipeline); set_params!(encoder, (size, mul, &input, output)); @@ -1083,7 +1084,6 @@ pub fn call_elu( encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - ep.maybe_end_encoding(encoder); Ok(()) } @@ -1103,6 +1103,7 @@ pub fn call_elu_strided( let size: usize = shape.iter().product(); let encoder = ep.encoder(); + let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); encoder.set_compute_pipeline_state(&pipeline); set_params!( @@ -1114,7 +1115,6 @@ pub fn call_elu_strided( encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - ep.maybe_end_encoding(encoder); Ok(()) } @@ -1136,6 +1136,7 @@ pub fn call_where_cond_strided( let pipeline = kernels.load_pipeline(device, Source::Ternary, name)?; let encoder = ep.encoder(); + let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); encoder.set_compute_pipeline_state(&pipeline); let size: usize = shape.iter().product(); @@ -1164,7 +1165,6 @@ pub fn call_where_cond_strided( encoder.use_resource(right.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - ep.maybe_end_encoding(encoder); Ok(()) } @@ -1192,6 +1192,7 @@ pub fn call_index_select( let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?; let encoder = ep.encoder(); + let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); encoder.set_compute_pipeline_state(&pipeline); @@ -1218,7 +1219,6 @@ pub fn call_index_select( encoder.use_resource(ids.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - ep.maybe_end_encoding(encoder); Ok(()) } @@ -1243,6 +1243,7 @@ pub fn call_gather( let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?; let encoder = ep.encoder(); + let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); encoder.set_compute_pipeline_state(&pipeline); @@ -1266,7 +1267,6 @@ pub fn call_gather( encoder.use_resource(ids.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - ep.maybe_end_encoding(encoder); Ok(()) } @@ -1292,6 +1292,7 @@ pub fn call_scatter_add( let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?; let encoder = ep.encoder(); + let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); encoder.set_compute_pipeline_state(&pipeline); @@ -1315,7 +1316,6 @@ pub fn call_scatter_add( encoder.use_resource(ids.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - ep.maybe_end_encoding(encoder); Ok(()) } @@ -1342,6 +1342,7 @@ pub fn call_index_add( let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?; let encoder = ep.encoder(); + let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); encoder.set_compute_pipeline_state(&pipeline); @@ -1366,7 +1367,6 @@ pub fn call_index_add( encoder.use_resource(ids.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - ep.maybe_end_encoding(encoder); Ok(()) } @@ -1573,6 +1573,7 @@ pub fn call_gemm( let block_bytes = block_elements * bytes; let encoder = ep.encoder(); + let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); encoder.set_compute_pipeline_state(&pipeline); encoder.set_threadgroup_memory_length(0, block_bytes.into()); encoder.set_buffer(0, Some(lhs_buffer), lhs_offset as NSUInteger); @@ -1615,8 +1616,6 @@ pub fn call_gemm( encoder.use_resource(rhs_buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(grid_size, group_size); - ep.maybe_end_encoding(encoder); - Ok(()) } @@ -1637,6 +1636,7 @@ pub fn call_im2col1d_strided( let dst_el = shape[0] * l_out * shape[1] * k_size; let encoder = ep.encoder(); + let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el); encoder.set_compute_pipeline_state(&pipeline); set_params!( @@ -1646,8 +1646,6 @@ pub fn call_im2col1d_strided( encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - ep.maybe_end_encoding(encoder); - Ok(()) } @@ -1670,6 +1668,7 @@ pub fn call_col2im1d( let dst_el = shape[0] * c_out * l_out; let encoder = ep.encoder(); + let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el); encoder.set_compute_pipeline_state(&pipeline); set_params!( @@ -1679,8 +1678,6 @@ pub fn call_col2im1d( encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - ep.maybe_end_encoding(encoder); - Ok(()) } @@ -1706,6 +1703,7 @@ pub fn call_im2col_strided( let dst_el = shape[0] * h_out * w_out * shape[1] * h_k * w_k; let encoder = ep.encoder(); + let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el); encoder.set_compute_pipeline_state(&pipeline); set_params!( @@ -1718,8 +1716,6 @@ pub fn call_im2col_strided( encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - ep.maybe_end_encoding(encoder); - Ok(()) } @@ -1742,6 +1738,7 @@ pub fn call_upsample_nearest_2d( let scale_h = shape[3] as f32 / out_h as f32; let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el); let encoder = ep.encoder(); + let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); encoder.set_compute_pipeline_state(&pipeline); set_params!( encoder, @@ -1750,8 +1747,6 @@ pub fn call_upsample_nearest_2d( encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - ep.maybe_end_encoding(encoder); - Ok(()) } @@ -1774,6 +1769,7 @@ pub fn call_random_uniform( } let pipeline = kernels.load_pipeline(device, Source::Random, name)?; let encoder = ep.encoder(); + let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); let odd = (length % 2 != 0) as usize; let (thread_group_count, thread_group_size) = linear_split(&pipeline, length / 2 + odd); @@ -1788,8 +1784,6 @@ pub fn call_random_uniform( ); encoder.use_resource(buffer, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - ep.maybe_end_encoding(encoder); - Ok(()) } @@ -1807,6 +1801,7 @@ pub fn call_random_normal( ) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, Source::Random, name)?; let encoder = ep.encoder(); + let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); let odd = (length % 2 != 0) as usize; let (thread_group_count, thread_group_size) = linear_split(&pipeline, length / 2 + odd); @@ -1821,8 +1816,6 @@ pub fn call_random_normal( ); encoder.use_resource(buffer, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - ep.maybe_end_encoding(encoder); - Ok(()) } @@ -1962,6 +1955,7 @@ pub fn call_quantized_matmul_mv_t( let pipeline = kernels.load_pipeline(device, Source::Quantized, name)?; let encoder = ep.encoder(); + let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); encoder.set_compute_pipeline_state(&pipeline); set_params!( @@ -1993,8 +1987,6 @@ pub fn call_quantized_matmul_mv_t( encoder.use_resource(dst, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_groups_count, threads_per_threadgroup); - ep.maybe_end_encoding(encoder); - Ok(()) } @@ -2023,6 +2015,7 @@ pub fn call_pool2d( let pipeline: ComputePipelineState = kernels.load_pipeline(device, Source::Conv, name)?; let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el); let encoder = ep.encoder(); + let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); encoder.set_compute_pipeline_state(&pipeline); set_params!( encoder, @@ -2031,7 +2024,6 @@ pub fn call_pool2d( encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - ep.maybe_end_encoding(encoder); Ok(()) } @@ -2062,6 +2054,7 @@ pub fn call_conv_transpose1d( let pipeline: ComputePipelineState = kernels.load_pipeline(device, Source::Conv, name)?; let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el); let encoder = ep.encoder(); + let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); encoder.set_compute_pipeline_state(&pipeline); set_params!( encoder, @@ -2084,7 +2077,6 @@ pub fn call_conv_transpose1d( encoder.use_resource(kernel, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - ep.maybe_end_encoding(encoder); Ok(()) } @@ -2120,6 +2112,7 @@ pub fn call_conv_transpose2d( let pipeline: ComputePipelineState = kernels.load_pipeline(device, Source::Conv, name)?; let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el); let encoder = ep.encoder(); + let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); encoder.set_compute_pipeline_state(&pipeline); set_params!( encoder, @@ -2143,7 +2136,6 @@ pub fn call_conv_transpose2d( encoder.use_resource(kernel, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - ep.maybe_end_encoding(encoder); Ok(()) } @@ -2161,6 +2153,7 @@ pub fn call_arg_sort( ) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, Source::Sort, name)?; let encoder = ep.encoder(); + let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); encoder.set_compute_pipeline_state(&pipeline); set_params!(encoder, (&src, dst, ncols as i64, ncols_pad as i64)); @@ -2180,7 +2173,6 @@ pub fn call_arg_sort( encoder.use_resource(dst, metal::MTLResourceUsage::Write); encoder.set_threadgroup_memory_length(0, (ncols_pad * 4).max(16) as u64); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - ep.maybe_end_encoding(encoder); Ok(()) } diff --git a/candle-metal-kernels/src/utils.rs b/candle-metal-kernels/src/utils.rs index 4ef2162c..b42bcff0 100644 --- a/candle-metal-kernels/src/utils.rs +++ b/candle-metal-kernels/src/utils.rs @@ -162,24 +162,40 @@ macro_rules! set_params { } pub trait EncoderProvider { - fn encoder(&self) -> &ComputeCommandEncoderRef; - fn maybe_end_encoding(&self, enc: &ComputeCommandEncoderRef); + type Encoder<'a>: AsRef<metal::ComputeCommandEncoderRef> + where + Self: 'a; + fn encoder<'a>(&'a self) -> Self::Encoder<'a>; } -impl EncoderProvider for &metal::CommandBuffer { - fn encoder(&self) -> &ComputeCommandEncoderRef { - self.new_compute_command_encoder() +pub struct WrappedEncoder<'a>(&'a ComputeCommandEncoderRef); + +impl<'a> Drop for WrappedEncoder<'a> { + fn drop(&mut self) { + self.0.end_encoding() } - fn maybe_end_encoding(&self, enc: &ComputeCommandEncoderRef) { - enc.end_encoding() +} + +impl<'a> AsRef<metal::ComputeCommandEncoderRef> for WrappedEncoder<'a> { + fn as_ref(&self) -> &metal::ComputeCommandEncoderRef { + &self.0 } } -impl EncoderProvider for &metal::CommandBufferRef { - fn encoder(&self) -> &ComputeCommandEncoderRef { - self.new_compute_command_encoder() +impl EncoderProvider for &metal::CommandBuffer { + type Encoder<'a> = WrappedEncoder<'a> + where + Self: 'a; + fn encoder<'a>(&'a self) -> Self::Encoder<'a> { + WrappedEncoder(self.new_compute_command_encoder()) } - fn maybe_end_encoding(&self, enc: &ComputeCommandEncoderRef) { - enc.end_encoding() +} + +impl EncoderProvider for &metal::CommandBufferRef { + type Encoder<'a> = WrappedEncoder<'a> + where + Self: 'a; + fn encoder<'a>(&'a self) -> Self::Encoder<'a> { + WrappedEncoder(self.new_compute_command_encoder()) } } |