summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-metal-kernels/src/lib.rs240
-rw-r--r--candle-metal-kernels/src/utils.rs23
2 files changed, 143 insertions, 120 deletions
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index 1815dd32..6f723a93 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -1,6 +1,6 @@
use metal::{
- Buffer, CommandBufferRef, CompileOptions, ComputePipelineState, Device, Function,
- FunctionConstantValues, Library, MTLDataType, MTLSize, NSUInteger,
+ Buffer, CompileOptions, ComputePipelineState, Device, Function, FunctionConstantValues,
+ Library, MTLDataType, MTLSize, NSUInteger,
};
use std::collections::HashMap;
use std::ffi::c_void;
@@ -8,7 +8,7 @@ use std::sync::RwLock;
mod utils;
pub use utils::BufferOffset;
-use utils::{get_block_dims, linear_split};
+use utils::{get_block_dims, linear_split, EncoderProvider};
const AFFINE: &str = include_str!("affine.metal");
const INDEXING: &str = include_str!("indexing.metal");
@@ -297,7 +297,7 @@ impl Kernels {
#[allow(clippy::too_many_arguments)]
pub fn call_copy2d(
device: &Device,
- command_buffer: &CommandBufferRef,
+ ep: impl EncoderProvider,
kernels: &Kernels,
name: copy2d::Kernel,
input: &Buffer,
@@ -310,7 +310,7 @@ pub fn call_copy2d(
dst_o_in_bytes: usize,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Unary, name.0)?;
- let encoder = command_buffer.new_compute_command_encoder();
+ let encoder = ep.encoder();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
encoder,
@@ -333,14 +333,14 @@ pub fn call_copy2d(
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_threads(grid_dims, group_dims);
- encoder.end_encoding();
+ ep.maybe_end_encoding(encoder);
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_unary_contiguous_tiled(
device: &Device,
- command_buffer: &CommandBufferRef,
+ ep: impl EncoderProvider,
kernels: &Kernels,
kernel_name: unary::contiguous_tiled::Kernel,
length: usize,
@@ -348,7 +348,7 @@ pub fn call_unary_contiguous_tiled(
output: &Buffer,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
- let encoder = command_buffer.new_compute_command_encoder();
+ let encoder = ep.encoder();
let tile_size = 2;
let tiles = (length + tile_size - 1) / tile_size;
@@ -360,14 +360,14 @@ pub fn call_unary_contiguous_tiled(
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
- encoder.end_encoding();
+ ep.maybe_end_encoding(encoder);
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_unary_contiguous(
device: &Device,
- command_buffer: &CommandBufferRef,
+ ep: impl EncoderProvider,
kernels: &Kernels,
kernel_name: unary::contiguous::Kernel,
length: usize,
@@ -375,7 +375,7 @@ pub fn call_unary_contiguous(
output: &Buffer,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
- let encoder = command_buffer.new_compute_command_encoder();
+ let encoder = ep.encoder();
encoder.set_compute_pipeline_state(&pipeline);
@@ -385,14 +385,14 @@ pub fn call_unary_contiguous(
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
- encoder.end_encoding();
+ ep.maybe_end_encoding(encoder);
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_unary_strided(
device: &Device,
- command_buffer: &CommandBufferRef,
+ ep: impl EncoderProvider,
kernels: &Kernels,
name: unary::strided::Kernel,
shape: &[usize],
@@ -404,7 +404,7 @@ pub fn call_unary_strided(
let length: usize = shape.iter().product();
let num_dims: usize = shape.len();
- let encoder = command_buffer.new_compute_command_encoder();
+ let encoder = ep.encoder();
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
encoder.set_compute_pipeline_state(&pipeline);
@@ -412,14 +412,14 @@ pub fn call_unary_strided(
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
- encoder.end_encoding();
+ ep.maybe_end_encoding(encoder);
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_binary_contiguous(
device: &Device,
- command_buffer: &CommandBufferRef,
+ ep: impl EncoderProvider,
kernels: &Kernels,
kernel_name: binary::contiguous::Kernel,
length: usize,
@@ -429,7 +429,7 @@ pub fn call_binary_contiguous(
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Binary, kernel_name.0)?;
- let encoder = command_buffer.new_compute_command_encoder();
+ let encoder = ep.encoder();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (length, &left, &right, output));
@@ -440,14 +440,14 @@ pub fn call_binary_contiguous(
encoder.use_resource(right.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
- encoder.end_encoding();
+ ep.maybe_end_encoding(encoder);
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_binary_strided(
device: &Device,
- command_buffer: &CommandBufferRef,
+ ep: impl EncoderProvider,
kernels: &Kernels,
name: binary::strided::Kernel,
shape: &[usize],
@@ -460,7 +460,7 @@ pub fn call_binary_strided(
let pipeline = kernels.load_pipeline(device, Source::Binary, name.0)?;
let num_dims: usize = shape.len();
- let encoder = command_buffer.new_compute_command_encoder();
+ let encoder = ep.encoder();
let width: usize = shape.iter().product();
let length: usize = shape.iter().product();
let (thread_group_count, thread_group_size) = linear_split(&pipeline, width);
@@ -483,7 +483,7 @@ pub fn call_binary_strided(
encoder.use_resource(right_input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
- encoder.end_encoding();
+ ep.maybe_end_encoding(encoder);
Ok(())
}
@@ -491,7 +491,7 @@ pub fn call_binary_strided(
#[allow(clippy::too_many_arguments)]
pub fn call_cast_contiguous(
device: &Device,
- command_buffer: &CommandBufferRef,
+ ep: impl EncoderProvider,
kernels: &Kernels,
kernel_name: &'static str,
length: usize,
@@ -500,7 +500,7 @@ pub fn call_cast_contiguous(
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?;
- let encoder = command_buffer.new_compute_command_encoder();
+ let encoder = ep.encoder();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (length, &input, output));
@@ -509,14 +509,14 @@ pub fn call_cast_contiguous(
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
- encoder.end_encoding();
+ ep.maybe_end_encoding(encoder);
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_cast_strided(
device: &Device,
- command_buffer: &CommandBufferRef,
+ ep: impl EncoderProvider,
kernels: &Kernels,
kernel_name: &'static str,
shape: &[usize],
@@ -526,7 +526,7 @@ pub fn call_cast_strided(
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?;
- let encoder = command_buffer.new_compute_command_encoder();
+ let encoder = ep.encoder();
encoder.set_compute_pipeline_state(&pipeline);
let length: usize = shape.iter().product();
@@ -541,14 +541,14 @@ pub fn call_cast_strided(
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
- encoder.end_encoding();
+ ep.maybe_end_encoding(encoder);
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_reduce_contiguous(
device: &Device,
- command_buffer: &CommandBufferRef,
+ ep: impl EncoderProvider,
kernels: &Kernels,
kernel_name: &'static str,
length: usize,
@@ -559,7 +559,7 @@ pub fn call_reduce_contiguous(
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
let elements_to_sum = length / out_length;
- let encoder = command_buffer.new_compute_command_encoder();
+ let encoder = ep.encoder();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (length, elements_to_sum, &input, output));
@@ -585,14 +585,14 @@ pub fn call_reduce_contiguous(
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
- encoder.end_encoding();
+ ep.maybe_end_encoding(encoder);
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_reduce_strided(
device: &Device,
- command_buffer: &CommandBufferRef,
+ ep: impl EncoderProvider,
kernels: &Kernels,
kernel_name: &'static str,
shape: &[usize],
@@ -605,7 +605,7 @@ pub fn call_reduce_strided(
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
let elements_to_sum = length / out_length;
- let encoder = command_buffer.new_compute_command_encoder();
+ let encoder = ep.encoder();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
@@ -634,14 +634,14 @@ pub fn call_reduce_strided(
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
- encoder.end_encoding();
+ ep.maybe_end_encoding(encoder);
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_last_softmax(
device: &Device,
- command_buffer: &CommandBufferRef,
+ ep: impl EncoderProvider,
kernels: &Kernels,
kernel_name: &'static str,
length: usize,
@@ -651,7 +651,7 @@ pub fn call_last_softmax(
output: &Buffer,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
- let encoder = command_buffer.new_compute_command_encoder();
+ let encoder = ep.encoder();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
@@ -682,14 +682,14 @@ pub fn call_last_softmax(
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
- encoder.end_encoding();
+ ep.maybe_end_encoding(encoder);
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_rms_norm(
device: &Device,
- command_buffer: &CommandBufferRef,
+ ep: impl EncoderProvider,
kernels: &Kernels,
kernel_name: &'static str,
length: usize,
@@ -702,7 +702,7 @@ pub fn call_rms_norm(
output: &Buffer,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
- let encoder = command_buffer.new_compute_command_encoder();
+ let encoder = ep.encoder();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
@@ -741,14 +741,14 @@ pub fn call_rms_norm(
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.set_threadgroup_memory_length(0, (width * 4).max(16) as u64);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
- encoder.end_encoding();
+ ep.maybe_end_encoding(encoder);
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_layer_norm(
device: &Device,
- command_buffer: &CommandBufferRef,
+ ep: impl EncoderProvider,
kernels: &Kernels,
kernel_name: &'static str,
length: usize,
@@ -763,7 +763,7 @@ pub fn call_layer_norm(
output: &Buffer,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
- let encoder = command_buffer.new_compute_command_encoder();
+ let encoder = ep.encoder();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
@@ -803,14 +803,14 @@ pub fn call_layer_norm(
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.set_threadgroup_memory_length(0, (width * 8).max(32) as u64);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
- encoder.end_encoding();
+ ep.maybe_end_encoding(encoder);
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_rope_i(
device: &Device,
- command_buffer: &CommandBufferRef,
+ ep: impl EncoderProvider,
kernels: &Kernels,
kernel_name: &'static str,
bh: usize,
@@ -824,7 +824,7 @@ pub fn call_rope_i(
output: &Buffer,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
- let encoder = command_buffer.new_compute_command_encoder();
+ let encoder = ep.encoder();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
@@ -844,14 +844,14 @@ pub fn call_rope_i(
encoder.use_resource(sin, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
- encoder.end_encoding();
+ ep.maybe_end_encoding(encoder);
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_rope_thd(
device: &Device,
- command_buffer: &CommandBufferRef,
+ ep: impl EncoderProvider,
kernels: &Kernels,
kernel_name: &'static str,
b: usize,
@@ -867,7 +867,7 @@ pub fn call_rope_thd(
output: &Buffer,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
- let encoder = command_buffer.new_compute_command_encoder();
+ let encoder = ep.encoder();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
@@ -889,14 +889,14 @@ pub fn call_rope_thd(
encoder.use_resource(sin, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
- encoder.end_encoding();
+ ep.maybe_end_encoding(encoder);
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_rope(
device: &Device,
- command_buffer: &CommandBufferRef,
+ ep: impl EncoderProvider,
kernels: &Kernels,
kernel_name: &'static str,
bh: usize,
@@ -911,7 +911,7 @@ pub fn call_rope(
output: &Buffer,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
- let encoder = command_buffer.new_compute_command_encoder();
+ let encoder = ep.encoder();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
@@ -932,14 +932,14 @@ pub fn call_rope(
encoder.use_resource(sin, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
- encoder.end_encoding();
+ ep.maybe_end_encoding(encoder);
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_affine(
device: &Device,
- command_buffer: &CommandBufferRef,
+ ep: impl EncoderProvider,
kernels: &Kernels,
name: &'static str,
size: usize,
@@ -950,7 +950,7 @@ pub fn call_affine(
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
- let encoder = command_buffer.new_compute_command_encoder();
+ let encoder = ep.encoder();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (size, mul, add, &input, output));
@@ -959,14 +959,14 @@ pub fn call_affine(
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
- encoder.end_encoding();
+ ep.maybe_end_encoding(encoder);
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_affine_strided(
device: &Device,
- command_buffer: &CommandBufferRef,
+ ep: impl EncoderProvider,
kernels: &Kernels,
name: &'static str,
shape: &[usize],
@@ -979,7 +979,7 @@ pub fn call_affine_strided(
let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
let size: usize = shape.iter().product();
- let encoder = command_buffer.new_compute_command_encoder();
+ let encoder = ep.encoder();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
@@ -1000,14 +1000,14 @@ pub fn call_affine_strided(
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
- encoder.end_encoding();
+ ep.maybe_end_encoding(encoder);
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_powf(
device: &Device,
- command_buffer: &CommandBufferRef,
+ ep: impl EncoderProvider,
kernels: &Kernels,
name: &'static str,
size: usize,
@@ -1017,7 +1017,7 @@ pub fn call_powf(
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
- let encoder = command_buffer.new_compute_command_encoder();
+ let encoder = ep.encoder();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (size, mul, &input, output));
@@ -1026,14 +1026,14 @@ pub fn call_powf(
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
- encoder.end_encoding();
+ ep.maybe_end_encoding(encoder);
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_powf_strided(
device: &Device,
- command_buffer: &CommandBufferRef,
+ ep: impl EncoderProvider,
kernels: &Kernels,
name: &'static str,
shape: &[usize],
@@ -1045,7 +1045,7 @@ pub fn call_powf_strided(
let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
let size: usize = shape.iter().product();
- let encoder = command_buffer.new_compute_command_encoder();
+ let encoder = ep.encoder();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
@@ -1057,14 +1057,14 @@ pub fn call_powf_strided(
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
- encoder.end_encoding();
+ ep.maybe_end_encoding(encoder);
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_elu(
device: &Device,
- command_buffer: &CommandBufferRef,
+ ep: impl EncoderProvider,
kernels: &Kernels,
name: &'static str,
size: usize,
@@ -1074,7 +1074,7 @@ pub fn call_elu(
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
- let encoder = command_buffer.new_compute_command_encoder();
+ let encoder = ep.encoder();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (size, mul, &input, output));
@@ -1083,14 +1083,14 @@ pub fn call_elu(
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
- encoder.end_encoding();
+ ep.maybe_end_encoding(encoder);
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_elu_strided(
device: &Device,
- command_buffer: &CommandBufferRef,
+ ep: impl EncoderProvider,
kernels: &Kernels,
name: &'static str,
shape: &[usize],
@@ -1102,7 +1102,7 @@ pub fn call_elu_strided(
let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
let size: usize = shape.iter().product();
- let encoder = command_buffer.new_compute_command_encoder();
+ let encoder = ep.encoder();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
@@ -1114,14 +1114,14 @@ pub fn call_elu_strided(
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
- encoder.end_encoding();
+ ep.maybe_end_encoding(encoder);
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_where_cond_strided(
device: &Device,
- command_buffer: &CommandBufferRef,
+ ep: impl EncoderProvider,
kernels: &Kernels,
name: &'static str,
shape: &[usize],
@@ -1135,7 +1135,7 @@ pub fn call_where_cond_strided(
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Ternary, name)?;
- let encoder = command_buffer.new_compute_command_encoder();
+ let encoder = ep.encoder();
encoder.set_compute_pipeline_state(&pipeline);
let size: usize = shape.iter().product();
@@ -1164,14 +1164,14 @@ pub fn call_where_cond_strided(
encoder.use_resource(right.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
- encoder.end_encoding();
+ ep.maybe_end_encoding(encoder);
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_index_select(
device: &Device,
- command_buffer: &CommandBufferRef,
+ ep: impl EncoderProvider,
kernels: &Kernels,
name: &'static str,
shape: &[usize],
@@ -1191,7 +1191,7 @@ pub fn call_index_select(
let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?;
- let encoder = command_buffer.new_compute_command_encoder();
+ let encoder = ep.encoder();
encoder.set_compute_pipeline_state(&pipeline);
@@ -1218,14 +1218,14 @@ pub fn call_index_select(
encoder.use_resource(ids.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
- encoder.end_encoding();
+ ep.maybe_end_encoding(encoder);
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_gather(
device: &Device,
- command_buffer: &CommandBufferRef,
+ ep: impl EncoderProvider,
kernels: &Kernels,
name: &'static str,
shape: &[usize],
@@ -1242,7 +1242,7 @@ pub fn call_gather(
let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?;
- let encoder = command_buffer.new_compute_command_encoder();
+ let encoder = ep.encoder();
encoder.set_compute_pipeline_state(&pipeline);
@@ -1266,14 +1266,14 @@ pub fn call_gather(
encoder.use_resource(ids.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
- encoder.end_encoding();
+ ep.maybe_end_encoding(encoder);
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_scatter_add(
device: &Device,
- command_buffer: &CommandBufferRef,
+ ep: impl EncoderProvider,
kernels: &Kernels,
name: &'static str,
src_shape: &[usize],
@@ -1291,7 +1291,7 @@ pub fn call_scatter_add(
let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?;
- let encoder = command_buffer.new_compute_command_encoder();
+ let encoder = ep.encoder();
encoder.set_compute_pipeline_state(&pipeline);
@@ -1315,14 +1315,14 @@ pub fn call_scatter_add(
encoder.use_resource(ids.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
- encoder.end_encoding();
+ ep.maybe_end_encoding(encoder);
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_index_add(
device: &Device,
- command_buffer: &CommandBufferRef,
+ ep: impl EncoderProvider,
kernels: &Kernels,
name: &'static str,
src_shape: &[usize],
@@ -1341,7 +1341,7 @@ pub fn call_index_add(
let ids_dim_size = ids_shape[0];
let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?;
- let encoder = command_buffer.new_compute_command_encoder();
+ let encoder = ep.encoder();
encoder.set_compute_pipeline_state(&pipeline);
@@ -1366,7 +1366,7 @@ pub fn call_index_add(
encoder.use_resource(ids.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
- encoder.end_encoding();
+ ep.maybe_end_encoding(encoder);
Ok(())
}
@@ -1453,7 +1453,7 @@ impl ConstantValues {
#[allow(clippy::too_many_arguments)]
pub fn call_gemm(
device: &Device,
- command_buffer: &CommandBufferRef,
+ ep: impl EncoderProvider,
kernels: &Kernels,
name: &'static str,
(b, m, n, k): (usize, usize, usize, usize),
@@ -1572,7 +1572,7 @@ pub fn call_gemm(
};
let block_bytes = block_elements * bytes;
- let encoder = command_buffer.new_compute_command_encoder();
+ let encoder = ep.encoder();
encoder.set_compute_pipeline_state(&pipeline);
encoder.set_threadgroup_memory_length(0, block_bytes.into());
encoder.set_buffer(0, Some(lhs_buffer), lhs_offset as NSUInteger);
@@ -1615,7 +1615,7 @@ pub fn call_gemm(
encoder.use_resource(rhs_buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(grid_size, group_size);
- encoder.end_encoding();
+ ep.maybe_end_encoding(encoder);
Ok(())
}
@@ -1623,7 +1623,7 @@ pub fn call_gemm(
#[allow(clippy::too_many_arguments)]
pub fn call_im2col1d_strided(
device: &Device,
- command_buffer: &CommandBufferRef,
+ ep: impl EncoderProvider,
kernels: &Kernels,
name: &'static str,
shape: &[usize],
@@ -1636,7 +1636,7 @@ pub fn call_im2col1d_strided(
let l_out = (shape[2] + 2 * padding - dilation * (k_size - 1) - 1) / stride + 1;
let dst_el = shape[0] * l_out * shape[1] * k_size;
- let encoder = command_buffer.new_compute_command_encoder();
+ let encoder = ep.encoder();
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
@@ -1646,7 +1646,7 @@ pub fn call_im2col1d_strided(
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
- encoder.end_encoding();
+ ep.maybe_end_encoding(encoder);
Ok(())
}
@@ -1654,7 +1654,7 @@ pub fn call_im2col1d_strided(
#[allow(clippy::too_many_arguments)]
pub fn call_col2im1d(
device: &Device,
- command_buffer: &CommandBufferRef,
+ ep: impl EncoderProvider,
kernels: &Kernels,
name: &'static str,
shape: &[usize],
@@ -1669,7 +1669,7 @@ pub fn call_col2im1d(
let l_out = (l_in - 1) * stride + k_size;
let dst_el = shape[0] * c_out * l_out;
- let encoder = command_buffer.new_compute_command_encoder();
+ let encoder = ep.encoder();
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
@@ -1679,7 +1679,7 @@ pub fn call_col2im1d(
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
- encoder.end_encoding();
+ ep.maybe_end_encoding(encoder);
Ok(())
}
@@ -1687,7 +1687,7 @@ pub fn call_col2im1d(
#[allow(clippy::too_many_arguments)]
pub fn call_im2col_strided(
device: &Device,
- command_buffer: &CommandBufferRef,
+ ep: impl EncoderProvider,
kernels: &Kernels,
name: &'static str,
shape: &[usize],
@@ -1705,7 +1705,7 @@ pub fn call_im2col_strided(
let dst_el = shape[0] * h_out * w_out * shape[1] * h_k * w_k;
- let encoder = command_buffer.new_compute_command_encoder();
+ let encoder = ep.encoder();
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
@@ -1718,7 +1718,7 @@ pub fn call_im2col_strided(
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
- encoder.end_encoding();
+ ep.maybe_end_encoding(encoder);
Ok(())
}
@@ -1726,7 +1726,7 @@ pub fn call_im2col_strided(
#[allow(clippy::too_many_arguments)]
pub fn call_upsample_nearest_2d(
device: &Device,
- command_buffer: &CommandBufferRef,
+ ep: impl EncoderProvider,
kernels: &Kernels,
name: &'static str,
shape: &[usize],
@@ -1741,7 +1741,7 @@ pub fn call_upsample_nearest_2d(
let scale_w = shape[2] as f32 / out_w as f32;
let scale_h = shape[3] as f32 / out_h as f32;
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
- let encoder = command_buffer.new_compute_command_encoder();
+ let encoder = ep.encoder();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
encoder,
@@ -1750,7 +1750,7 @@ pub fn call_upsample_nearest_2d(
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
- encoder.end_encoding();
+ ep.maybe_end_encoding(encoder);
Ok(())
}
@@ -1758,7 +1758,7 @@ pub fn call_upsample_nearest_2d(
#[allow(clippy::too_many_arguments)]
pub fn call_random_uniform(
device: &Device,
- command_buffer: &CommandBufferRef,
+ ep: impl EncoderProvider,
kernels: &Kernels,
name: &'static str,
min: f32,
@@ -1773,7 +1773,7 @@ pub fn call_random_uniform(
));
}
let pipeline = kernels.load_pipeline(device, Source::Random, name)?;
- let encoder = command_buffer.new_compute_command_encoder();
+ let encoder = ep.encoder();
let odd = (length % 2 != 0) as usize;
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length / 2 + odd);
@@ -1788,7 +1788,7 @@ pub fn call_random_uniform(
);
encoder.use_resource(buffer, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
- encoder.end_encoding();
+ ep.maybe_end_encoding(encoder);
Ok(())
}
@@ -1796,7 +1796,7 @@ pub fn call_random_uniform(
#[allow(clippy::too_many_arguments)]
pub fn call_random_normal(
device: &Device,
- command_buffer: &CommandBufferRef,
+ ep: impl EncoderProvider,
kernels: &Kernels,
name: &'static str,
mean: f32,
@@ -1806,7 +1806,7 @@ pub fn call_random_normal(
buffer: &Buffer,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Random, name)?;
- let encoder = command_buffer.new_compute_command_encoder();
+ let encoder = ep.encoder();
let odd = (length % 2 != 0) as usize;
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length / 2 + odd);
@@ -1821,7 +1821,7 @@ pub fn call_random_normal(
);
encoder.use_resource(buffer, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
- encoder.end_encoding();
+ ep.maybe_end_encoding(encoder);
Ok(())
}
@@ -1847,7 +1847,7 @@ pub enum GgmlDType {
#[allow(clippy::too_many_arguments)]
pub fn call_quantized_matmul_mv_t(
device: &Device,
- command_buffer: &CommandBufferRef,
+ ep: impl EncoderProvider,
kernels: &Kernels,
dtype: GgmlDType,
(b, m, n, k): (usize, usize, usize, usize),
@@ -1961,7 +1961,7 @@ pub fn call_quantized_matmul_mv_t(
};
let pipeline = kernels.load_pipeline(device, Source::Quantized, name)?;
- let encoder = command_buffer.new_compute_command_encoder();
+ let encoder = ep.encoder();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
@@ -1993,7 +1993,7 @@ pub fn call_quantized_matmul_mv_t(
encoder.use_resource(dst, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_groups_count, threads_per_threadgroup);
- encoder.end_encoding();
+ ep.maybe_end_encoding(encoder);
Ok(())
}
@@ -2005,7 +2005,7 @@ fn divide(m: usize, b: usize) -> NSUInteger {
#[allow(clippy::too_many_arguments)]
pub fn call_pool2d(
device: &Device,
- command_buffer: &CommandBufferRef,
+ ep: impl EncoderProvider,
kernels: &Kernels,
name: &'static str,
shape: &[usize],
@@ -2022,7 +2022,7 @@ pub fn call_pool2d(
let dst_el = out_w * out_h * shape[0] * shape[1];
let pipeline: ComputePipelineState = kernels.load_pipeline(device, Source::Conv, name)?;
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
- let encoder = command_buffer.new_compute_command_encoder();
+ let encoder = ep.encoder();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
encoder,
@@ -2031,14 +2031,14 @@ pub fn call_pool2d(
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
- encoder.end_encoding();
+ ep.maybe_end_encoding(encoder);
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_conv_transpose1d(
device: &Device,
- command_buffer: &CommandBufferRef,
+ ep: impl EncoderProvider,
kernels: &Kernels,
name: &'static str,
dilation: usize,
@@ -2061,7 +2061,7 @@ pub fn call_conv_transpose1d(
let dst_el = c_out * l_out * b_size;
let pipeline: ComputePipelineState = kernels.load_pipeline(device, Source::Conv, name)?;
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
- let encoder = command_buffer.new_compute_command_encoder();
+ let encoder = ep.encoder();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
encoder,
@@ -2084,7 +2084,7 @@ pub fn call_conv_transpose1d(
encoder.use_resource(kernel, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
- encoder.end_encoding();
+ ep.maybe_end_encoding(encoder);
Ok(())
}
@@ -2108,7 +2108,7 @@ pub struct CallConvTranspose2dCfg<'a> {
#[allow(clippy::too_many_arguments)]
pub fn call_conv_transpose2d(
device: &Device,
- command_buffer: &CommandBufferRef,
+ ep: impl EncoderProvider,
kernels: &Kernels,
name: &'static str,
cfg: CallConvTranspose2dCfg,
@@ -2119,7 +2119,7 @@ pub fn call_conv_transpose2d(
let dst_el = cfg.c_out * cfg.out_w * cfg.out_h * cfg.b_size;
let pipeline: ComputePipelineState = kernels.load_pipeline(device, Source::Conv, name)?;
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
- let encoder = command_buffer.new_compute_command_encoder();
+ let encoder = ep.encoder();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
encoder,
@@ -2143,14 +2143,14 @@ pub fn call_conv_transpose2d(
encoder.use_resource(kernel, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
- encoder.end_encoding();
+ ep.maybe_end_encoding(encoder);
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_arg_sort(
device: &Device,
- command_buffer: &CommandBufferRef,
+ ep: impl EncoderProvider,
kernels: &Kernels,
name: &'static str,
nrows: usize,
@@ -2160,7 +2160,7 @@ pub fn call_arg_sort(
dst: &Buffer,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Sort, name)?;
- let encoder = command_buffer.new_compute_command_encoder();
+ let encoder = ep.encoder();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (&src, dst, ncols as i64, ncols_pad as i64));
@@ -2180,7 +2180,7 @@ pub fn call_arg_sort(
encoder.use_resource(dst, metal::MTLResourceUsage::Write);
encoder.set_threadgroup_memory_length(0, (ncols_pad * 4).max(16) as u64);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
- encoder.end_encoding();
+ ep.maybe_end_encoding(encoder);
Ok(())
}
diff --git a/candle-metal-kernels/src/utils.rs b/candle-metal-kernels/src/utils.rs
index 194cddf4..4ef2162c 100644
--- a/candle-metal-kernels/src/utils.rs
+++ b/candle-metal-kernels/src/utils.rs
@@ -160,3 +160,26 @@ macro_rules! set_params {
)*
);
}
+
+pub trait EncoderProvider {
+ fn encoder(&self) -> &ComputeCommandEncoderRef;
+ fn maybe_end_encoding(&self, enc: &ComputeCommandEncoderRef);
+}
+
+impl EncoderProvider for &metal::CommandBuffer {
+ fn encoder(&self) -> &ComputeCommandEncoderRef {
+ self.new_compute_command_encoder()
+ }
+ fn maybe_end_encoding(&self, enc: &ComputeCommandEncoderRef) {
+ enc.end_encoding()
+ }
+}
+
+impl EncoderProvider for &metal::CommandBufferRef {
+ fn encoder(&self) -> &ComputeCommandEncoderRef {
+ self.new_compute_command_encoder()
+ }
+ fn maybe_end_encoding(&self, enc: &ComputeCommandEncoderRef) {
+ enc.end_encoding()
+ }
+}