summaryrefslogtreecommitdiff
path: root/candle-metal-kernels
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2024-07-24 15:29:56 +0100
committerGitHub <noreply@github.com>2024-07-24 16:29:56 +0200
commitddafc61055601002622778b7762c15bd60057c1f (patch)
tree5363cf002fb93d9c0368140c1775721aa06d98bd /candle-metal-kernels
parenta925ae6bc659d1b40570b5068b6913d38e75b12e (diff)
downloadcandle-ddafc61055601002622778b7762c15bd60057c1f.tar.gz
candle-ddafc61055601002622778b7762c15bd60057c1f.tar.bz2
candle-ddafc61055601002622778b7762c15bd60057c1f.zip
Use RAII for terminating the encoding. (#2353)
Diffstat (limited to 'candle-metal-kernels')
-rw-r--r--candle-metal-kernels/src/lib.rs90
-rw-r--r--candle-metal-kernels/src/utils.rs40
2 files changed, 69 insertions, 61 deletions
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index 6f723a93..e0c97962 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -1,6 +1,6 @@
use metal::{
- Buffer, CompileOptions, ComputePipelineState, Device, Function, FunctionConstantValues,
- Library, MTLDataType, MTLSize, NSUInteger,
+ Buffer, CompileOptions, ComputeCommandEncoderRef, ComputePipelineState, Device, Function,
+ FunctionConstantValues, Library, MTLDataType, MTLSize, NSUInteger,
};
use std::collections::HashMap;
use std::ffi::c_void;
@@ -311,6 +311,7 @@ pub fn call_copy2d(
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Unary, name.0)?;
let encoder = ep.encoder();
+ let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
encoder,
@@ -333,7 +334,6 @@ pub fn call_copy2d(
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_threads(grid_dims, group_dims);
- ep.maybe_end_encoding(encoder);
Ok(())
}
@@ -349,6 +349,7 @@ pub fn call_unary_contiguous_tiled(
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
let encoder = ep.encoder();
+ let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
let tile_size = 2;
let tiles = (length + tile_size - 1) / tile_size;
@@ -360,7 +361,6 @@ pub fn call_unary_contiguous_tiled(
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
- ep.maybe_end_encoding(encoder);
Ok(())
}
@@ -376,6 +376,7 @@ pub fn call_unary_contiguous(
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
let encoder = ep.encoder();
+ let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline);
@@ -385,7 +386,6 @@ pub fn call_unary_contiguous(
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
- ep.maybe_end_encoding(encoder);
Ok(())
}
@@ -405,6 +405,7 @@ pub fn call_unary_strided(
let length: usize = shape.iter().product();
let num_dims: usize = shape.len();
let encoder = ep.encoder();
+ let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
encoder.set_compute_pipeline_state(&pipeline);
@@ -412,7 +413,6 @@ pub fn call_unary_strided(
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
- ep.maybe_end_encoding(encoder);
Ok(())
}
@@ -430,6 +430,7 @@ pub fn call_binary_contiguous(
let pipeline = kernels.load_pipeline(device, Source::Binary, kernel_name.0)?;
let encoder = ep.encoder();
+ let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (length, &left, &right, output));
@@ -440,7 +441,6 @@ pub fn call_binary_contiguous(
encoder.use_resource(right.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
- ep.maybe_end_encoding(encoder);
Ok(())
}
@@ -461,6 +461,7 @@ pub fn call_binary_strided(
let num_dims: usize = shape.len();
let encoder = ep.encoder();
+ let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
let width: usize = shape.iter().product();
let length: usize = shape.iter().product();
let (thread_group_count, thread_group_size) = linear_split(&pipeline, width);
@@ -483,7 +484,6 @@ pub fn call_binary_strided(
encoder.use_resource(right_input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
- ep.maybe_end_encoding(encoder);
Ok(())
}
@@ -501,6 +501,7 @@ pub fn call_cast_contiguous(
let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?;
let encoder = ep.encoder();
+ let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (length, &input, output));
@@ -509,7 +510,6 @@ pub fn call_cast_contiguous(
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
- ep.maybe_end_encoding(encoder);
Ok(())
}
@@ -527,6 +527,7 @@ pub fn call_cast_strided(
let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?;
let encoder = ep.encoder();
+ let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline);
let length: usize = shape.iter().product();
@@ -541,7 +542,6 @@ pub fn call_cast_strided(
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
- ep.maybe_end_encoding(encoder);
Ok(())
}
@@ -560,6 +560,7 @@ pub fn call_reduce_contiguous(
let elements_to_sum = length / out_length;
let encoder = ep.encoder();
+ let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (length, elements_to_sum, &input, output));
@@ -585,7 +586,6 @@ pub fn call_reduce_contiguous(
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
- ep.maybe_end_encoding(encoder);
Ok(())
}
@@ -606,6 +606,7 @@ pub fn call_reduce_strided(
let elements_to_sum = length / out_length;
let encoder = ep.encoder();
+ let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
@@ -634,7 +635,6 @@ pub fn call_reduce_strided(
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
- ep.maybe_end_encoding(encoder);
Ok(())
}
@@ -652,6 +652,7 @@ pub fn call_last_softmax(
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
let encoder = ep.encoder();
+ let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
@@ -682,7 +683,6 @@ pub fn call_last_softmax(
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
- ep.maybe_end_encoding(encoder);
Ok(())
}
@@ -703,6 +703,7 @@ pub fn call_rms_norm(
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
let encoder = ep.encoder();
+ let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
@@ -741,7 +742,6 @@ pub fn call_rms_norm(
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.set_threadgroup_memory_length(0, (width * 4).max(16) as u64);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
- ep.maybe_end_encoding(encoder);
Ok(())
}
@@ -764,6 +764,7 @@ pub fn call_layer_norm(
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
let encoder = ep.encoder();
+ let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
@@ -803,7 +804,6 @@ pub fn call_layer_norm(
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.set_threadgroup_memory_length(0, (width * 8).max(32) as u64);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
- ep.maybe_end_encoding(encoder);
Ok(())
}
@@ -825,6 +825,7 @@ pub fn call_rope_i(
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
let encoder = ep.encoder();
+ let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
@@ -844,7 +845,6 @@ pub fn call_rope_i(
encoder.use_resource(sin, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
- ep.maybe_end_encoding(encoder);
Ok(())
}
@@ -868,6 +868,7 @@ pub fn call_rope_thd(
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
let encoder = ep.encoder();
+ let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
@@ -889,7 +890,6 @@ pub fn call_rope_thd(
encoder.use_resource(sin, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
- ep.maybe_end_encoding(encoder);
Ok(())
}
@@ -912,6 +912,7 @@ pub fn call_rope(
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
let encoder = ep.encoder();
+ let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
@@ -932,7 +933,6 @@ pub fn call_rope(
encoder.use_resource(sin, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
- ep.maybe_end_encoding(encoder);
Ok(())
}
@@ -951,6 +951,7 @@ pub fn call_affine(
let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
let encoder = ep.encoder();
+ let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (size, mul, add, &input, output));
@@ -959,7 +960,6 @@ pub fn call_affine(
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
- ep.maybe_end_encoding(encoder);
Ok(())
}
@@ -980,6 +980,7 @@ pub fn call_affine_strided(
let size: usize = shape.iter().product();
let encoder = ep.encoder();
+ let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
@@ -1000,7 +1001,6 @@ pub fn call_affine_strided(
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
- ep.maybe_end_encoding(encoder);
Ok(())
}
@@ -1018,6 +1018,7 @@ pub fn call_powf(
let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
let encoder = ep.encoder();
+ let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (size, mul, &input, output));
@@ -1026,7 +1027,6 @@ pub fn call_powf(
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
- ep.maybe_end_encoding(encoder);
Ok(())
}
@@ -1046,6 +1046,7 @@ pub fn call_powf_strided(
let size: usize = shape.iter().product();
let encoder = ep.encoder();
+ let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
@@ -1057,7 +1058,6 @@ pub fn call_powf_strided(
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
- ep.maybe_end_encoding(encoder);
Ok(())
}
@@ -1075,6 +1075,7 @@ pub fn call_elu(
let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
let encoder = ep.encoder();
+ let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (size, mul, &input, output));
@@ -1083,7 +1084,6 @@ pub fn call_elu(
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
- ep.maybe_end_encoding(encoder);
Ok(())
}
@@ -1103,6 +1103,7 @@ pub fn call_elu_strided(
let size: usize = shape.iter().product();
let encoder = ep.encoder();
+ let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
@@ -1114,7 +1115,6 @@ pub fn call_elu_strided(
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
- ep.maybe_end_encoding(encoder);
Ok(())
}
@@ -1136,6 +1136,7 @@ pub fn call_where_cond_strided(
let pipeline = kernels.load_pipeline(device, Source::Ternary, name)?;
let encoder = ep.encoder();
+ let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline);
let size: usize = shape.iter().product();
@@ -1164,7 +1165,6 @@ pub fn call_where_cond_strided(
encoder.use_resource(right.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
- ep.maybe_end_encoding(encoder);
Ok(())
}
@@ -1192,6 +1192,7 @@ pub fn call_index_select(
let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?;
let encoder = ep.encoder();
+ let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline);
@@ -1218,7 +1219,6 @@ pub fn call_index_select(
encoder.use_resource(ids.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
- ep.maybe_end_encoding(encoder);
Ok(())
}
@@ -1243,6 +1243,7 @@ pub fn call_gather(
let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?;
let encoder = ep.encoder();
+ let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline);
@@ -1266,7 +1267,6 @@ pub fn call_gather(
encoder.use_resource(ids.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
- ep.maybe_end_encoding(encoder);
Ok(())
}
@@ -1292,6 +1292,7 @@ pub fn call_scatter_add(
let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?;
let encoder = ep.encoder();
+ let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline);
@@ -1315,7 +1316,6 @@ pub fn call_scatter_add(
encoder.use_resource(ids.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
- ep.maybe_end_encoding(encoder);
Ok(())
}
@@ -1342,6 +1342,7 @@ pub fn call_index_add(
let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?;
let encoder = ep.encoder();
+ let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline);
@@ -1366,7 +1367,6 @@ pub fn call_index_add(
encoder.use_resource(ids.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
- ep.maybe_end_encoding(encoder);
Ok(())
}
@@ -1573,6 +1573,7 @@ pub fn call_gemm(
let block_bytes = block_elements * bytes;
let encoder = ep.encoder();
+ let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline);
encoder.set_threadgroup_memory_length(0, block_bytes.into());
encoder.set_buffer(0, Some(lhs_buffer), lhs_offset as NSUInteger);
@@ -1615,8 +1616,6 @@ pub fn call_gemm(
encoder.use_resource(rhs_buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(grid_size, group_size);
- ep.maybe_end_encoding(encoder);
-
Ok(())
}
@@ -1637,6 +1636,7 @@ pub fn call_im2col1d_strided(
let dst_el = shape[0] * l_out * shape[1] * k_size;
let encoder = ep.encoder();
+ let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
@@ -1646,8 +1646,6 @@ pub fn call_im2col1d_strided(
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
- ep.maybe_end_encoding(encoder);
-
Ok(())
}
@@ -1670,6 +1668,7 @@ pub fn call_col2im1d(
let dst_el = shape[0] * c_out * l_out;
let encoder = ep.encoder();
+ let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
@@ -1679,8 +1678,6 @@ pub fn call_col2im1d(
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
- ep.maybe_end_encoding(encoder);
-
Ok(())
}
@@ -1706,6 +1703,7 @@ pub fn call_im2col_strided(
let dst_el = shape[0] * h_out * w_out * shape[1] * h_k * w_k;
let encoder = ep.encoder();
+ let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
@@ -1718,8 +1716,6 @@ pub fn call_im2col_strided(
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
- ep.maybe_end_encoding(encoder);
-
Ok(())
}
@@ -1742,6 +1738,7 @@ pub fn call_upsample_nearest_2d(
let scale_h = shape[3] as f32 / out_h as f32;
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
let encoder = ep.encoder();
+ let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
encoder,
@@ -1750,8 +1747,6 @@ pub fn call_upsample_nearest_2d(
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
- ep.maybe_end_encoding(encoder);
-
Ok(())
}
@@ -1774,6 +1769,7 @@ pub fn call_random_uniform(
}
let pipeline = kernels.load_pipeline(device, Source::Random, name)?;
let encoder = ep.encoder();
+ let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
let odd = (length % 2 != 0) as usize;
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length / 2 + odd);
@@ -1788,8 +1784,6 @@ pub fn call_random_uniform(
);
encoder.use_resource(buffer, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
- ep.maybe_end_encoding(encoder);
-
Ok(())
}
@@ -1807,6 +1801,7 @@ pub fn call_random_normal(
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Random, name)?;
let encoder = ep.encoder();
+ let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
let odd = (length % 2 != 0) as usize;
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length / 2 + odd);
@@ -1821,8 +1816,6 @@ pub fn call_random_normal(
);
encoder.use_resource(buffer, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
- ep.maybe_end_encoding(encoder);
-
Ok(())
}
@@ -1962,6 +1955,7 @@ pub fn call_quantized_matmul_mv_t(
let pipeline = kernels.load_pipeline(device, Source::Quantized, name)?;
let encoder = ep.encoder();
+ let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
@@ -1993,8 +1987,6 @@ pub fn call_quantized_matmul_mv_t(
encoder.use_resource(dst, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_groups_count, threads_per_threadgroup);
- ep.maybe_end_encoding(encoder);
-
Ok(())
}
@@ -2023,6 +2015,7 @@ pub fn call_pool2d(
let pipeline: ComputePipelineState = kernels.load_pipeline(device, Source::Conv, name)?;
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
let encoder = ep.encoder();
+ let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
encoder,
@@ -2031,7 +2024,6 @@ pub fn call_pool2d(
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
- ep.maybe_end_encoding(encoder);
Ok(())
}
@@ -2062,6 +2054,7 @@ pub fn call_conv_transpose1d(
let pipeline: ComputePipelineState = kernels.load_pipeline(device, Source::Conv, name)?;
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
let encoder = ep.encoder();
+ let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
encoder,
@@ -2084,7 +2077,6 @@ pub fn call_conv_transpose1d(
encoder.use_resource(kernel, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
- ep.maybe_end_encoding(encoder);
Ok(())
}
@@ -2120,6 +2112,7 @@ pub fn call_conv_transpose2d(
let pipeline: ComputePipelineState = kernels.load_pipeline(device, Source::Conv, name)?;
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
let encoder = ep.encoder();
+ let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
encoder,
@@ -2143,7 +2136,6 @@ pub fn call_conv_transpose2d(
encoder.use_resource(kernel, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
- ep.maybe_end_encoding(encoder);
Ok(())
}
@@ -2161,6 +2153,7 @@ pub fn call_arg_sort(
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Sort, name)?;
let encoder = ep.encoder();
+ let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (&src, dst, ncols as i64, ncols_pad as i64));
@@ -2180,7 +2173,6 @@ pub fn call_arg_sort(
encoder.use_resource(dst, metal::MTLResourceUsage::Write);
encoder.set_threadgroup_memory_length(0, (ncols_pad * 4).max(16) as u64);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
- ep.maybe_end_encoding(encoder);
Ok(())
}
diff --git a/candle-metal-kernels/src/utils.rs b/candle-metal-kernels/src/utils.rs
index 4ef2162c..b42bcff0 100644
--- a/candle-metal-kernels/src/utils.rs
+++ b/candle-metal-kernels/src/utils.rs
@@ -162,24 +162,40 @@ macro_rules! set_params {
}
pub trait EncoderProvider {
- fn encoder(&self) -> &ComputeCommandEncoderRef;
- fn maybe_end_encoding(&self, enc: &ComputeCommandEncoderRef);
+ type Encoder<'a>: AsRef<metal::ComputeCommandEncoderRef>
+ where
+ Self: 'a;
+ fn encoder<'a>(&'a self) -> Self::Encoder<'a>;
}
-impl EncoderProvider for &metal::CommandBuffer {
- fn encoder(&self) -> &ComputeCommandEncoderRef {
- self.new_compute_command_encoder()
+pub struct WrappedEncoder<'a>(&'a ComputeCommandEncoderRef);
+
+impl<'a> Drop for WrappedEncoder<'a> {
+ fn drop(&mut self) {
+ self.0.end_encoding()
}
- fn maybe_end_encoding(&self, enc: &ComputeCommandEncoderRef) {
- enc.end_encoding()
+}
+
+impl<'a> AsRef<metal::ComputeCommandEncoderRef> for WrappedEncoder<'a> {
+ fn as_ref(&self) -> &metal::ComputeCommandEncoderRef {
+ &self.0
}
}
-impl EncoderProvider for &metal::CommandBufferRef {
- fn encoder(&self) -> &ComputeCommandEncoderRef {
- self.new_compute_command_encoder()
+impl EncoderProvider for &metal::CommandBuffer {
+ type Encoder<'a> = WrappedEncoder<'a>
+ where
+ Self: 'a;
+ fn encoder<'a>(&'a self) -> Self::Encoder<'a> {
+ WrappedEncoder(self.new_compute_command_encoder())
}
- fn maybe_end_encoding(&self, enc: &ComputeCommandEncoderRef) {
- enc.end_encoding()
+}
+
+impl EncoderProvider for &metal::CommandBufferRef {
+ type Encoder<'a> = WrappedEncoder<'a>
+ where
+ Self: 'a;
+ fn encoder<'a>(&'a self) -> Self::Encoder<'a> {
+ WrappedEncoder(self.new_compute_command_encoder())
}
}