summaryrefslogtreecommitdiff
path: root/candle-metal-kernels/src/lib.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-metal-kernels/src/lib.rs')
-rw-r--r--candle-metal-kernels/src/lib.rs1045
1 files changed, 898 insertions, 147 deletions
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index 5a6bd41b..0418c96c 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -1,6 +1,6 @@
use metal::{
- Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineDescriptor,
- ComputePipelineState, Device, Function, Library, MTLSize,
+ Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineState,
+ Device, Function, FunctionConstantValues, Library, MTLDataType, MTLSize, NSUInteger,
};
use std::collections::HashMap;
use std::ffi::c_void;
@@ -13,7 +13,12 @@ const BINARY: &str = include_str!("binary.metal");
const TERNARY: &str = include_str!("ternary.metal");
const CAST: &str = include_str!("cast.metal");
const REDUCE: &str = include_str!("reduce.metal");
+const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");
+/// Most kernels apply similarly across the tensors
+/// This creates a strategy that uses the maximum amount of threads per threadgroup (capped at the
+/// actual total buffer length).
+/// Then kernels can just do their op on their single point in the buffer.
fn linear_split(pipeline: &ComputePipelineState, length: usize) -> (MTLSize, MTLSize) {
let size = length as u64;
let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), size);
@@ -35,6 +40,10 @@ fn linear_split(pipeline: &ComputePipelineState, length: usize) -> (MTLSize, MTL
fn set_param<P: EncoderParam>(encoder: &ComputeCommandEncoderRef, position: u64, data: P) {
<P as EncoderParam>::set_param(encoder, position, data)
}
+
+/// Helper functions to create the various objects on the compute command encoder
+/// on a single line.
+/// Prevents getting wrong some arguments number and mixing length and size in bytes.
trait EncoderParam {
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self);
}
@@ -59,8 +68,8 @@ impl<T> EncoderParam for &[T] {
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
encoder.set_bytes(
position,
- (core::mem::size_of::<T>() * data.len()) as u64,
- data.as_ptr() as *const T as *const c_void,
+ core::mem::size_of_val(data) as u64,
+ data.as_ptr() as *const c_void,
);
}
}
@@ -105,54 +114,59 @@ pub enum Source {
Ternary,
Cast,
Reduce,
+ Mfa,
}
macro_rules! ops{
($($name:ident),+) => {
pub mod contiguous {
- #[derive(Clone, Copy)]
- pub struct Kernel(pub(crate) &'static str);
- impl std::fmt::Display for Kernel {
- fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
- write!(f, "{}", self.0)
- }
- }
+ pub struct Kernel(pub &'static str);
$(
pub mod $name {
use super::Kernel;
- pub const FLOAT: Kernel = Kernel(concat!(stringify!($name), "_float"));
- pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_half"));
- pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bfloat"));
+ pub const FLOAT: Kernel = Kernel(concat!(stringify!($name), "_f32"));
+ pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_f16"));
+ pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bf16"));
}
)+
+ pub mod copy {
+ use super::Kernel;
+ pub const FLOAT: Kernel = Kernel("copy_f32");
+ pub const HALF: Kernel = Kernel("copy_f16");
+ pub const BFLOAT: Kernel = Kernel("copy_bf16");
+ pub const U32: Kernel = Kernel("copy_u32");
+ pub const U8: Kernel = Kernel("copy_u8");
+ }
}
pub mod strided {
- #[derive(Clone, Copy)]
- pub struct Kernel(pub(crate) &'static str);
- impl std::fmt::Display for Kernel {
- fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
- write!(f, "{}", self.0)
- }
- }
+ pub struct Kernel(pub &'static str);
$(
pub mod $name {
use super::Kernel;
- pub const FLOAT: Kernel = Kernel(concat!(stringify!($name), "_float_strided"));
- pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_half_strided"));
- pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bfloat_strided"));
+ pub const FLOAT: Kernel = Kernel(concat!(stringify!($name), "_f32_strided"));
+ pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_f16_strided"));
+ pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bf16_strided"));
}
)+
+ pub mod copy {
+ use super::Kernel;
+ pub const FLOAT: Kernel = Kernel("copy_f32_strided");
+ pub const HALF: Kernel = Kernel("copy_f16_strided");
+ pub const BFLOAT: Kernel = Kernel("copy_bf16_strided");
+ pub const U32: Kernel = Kernel("copy_u32_strided");
+ pub const U8: Kernel = Kernel("copy_u8_strided");
+ }
}
};
}
pub mod unary {
- ops!(cos, sin, exp, sqr, sqrt, neg, copy, log);
+ ops!(cos, sin, exp, sqr, sqrt, neg, log, gelu, ceil, floor, round, erf, gelu_erf, tanh);
}
pub mod binary {
- ops!(add, sub, mul, div);
+ ops!(add, sub, mul, div, min, max, eq, ne, le, lt, ge, gt);
}
#[derive(thiserror::Error, Debug)]
@@ -161,8 +175,18 @@ pub enum MetalKernelError {
LockError(String),
#[error("Error while loading library: {0}")]
LoadLibraryError(String),
- #[error("Error while loading function: {0}")]
+ #[error("Error while loading function: {0:?}")]
LoadFunctionError(String),
+ #[error("Failed to create compute function")]
+ FailedToCreateComputeFunction,
+ #[error("Failed to create pipeline")]
+ FailedToCreatePipeline(String),
+ #[error("Invalid matmul arguments {lhs_stride:?} {rhs_stride:?} {mnk:?}")]
+ MatMulNonContiguous {
+ lhs_stride: Vec<usize>,
+ rhs_stride: Vec<usize>,
+ mnk: (usize, usize, usize),
+ },
}
impl<T> From<std::sync::PoisonError<T>> for MetalKernelError {
@@ -171,21 +195,25 @@ impl<T> From<std::sync::PoisonError<T>> for MetalKernelError {
}
}
-type KernelMap<T> = HashMap<&'static str, T>;
type Libraries = HashMap<Source, Library>;
-type Functions = KernelMap<Function>;
+type Pipelines = HashMap<(&'static str, Option<ConstantValues>), ComputePipelineState>;
-#[derive(Debug, Default)]
+#[derive(Debug)]
pub struct Kernels {
libraries: RwLock<Libraries>,
- funcs: RwLock<Functions>,
+ pipelines: RwLock<Pipelines>,
+ fence: metal::Fence,
}
impl Kernels {
- pub fn new() -> Self {
+ pub fn new(fence: metal::Fence) -> Self {
let libraries = RwLock::new(Libraries::new());
- let funcs = RwLock::new(Functions::new());
- Self { libraries, funcs }
+ let pipelines = RwLock::new(Pipelines::new());
+ Self {
+ libraries,
+ pipelines,
+ fence,
+ }
}
fn get_library_source(&self, source: Source) -> &'static str {
@@ -197,9 +225,12 @@ impl Kernels {
Source::Indexing => INDEXING,
Source::Cast => CAST,
Source::Reduce => REDUCE,
+ Source::Mfa => panic!("Invalid lib"),
}
}
+ /// Load the give library from its [`source`].
+ /// If this has been previously loaded it will just fetch it from cache.
pub fn load_library(
&self,
device: &Device,
@@ -209,33 +240,83 @@ impl Kernels {
if let Some(lib) = libraries.get(&source) {
Ok(lib.clone())
} else {
- let source_content = self.get_library_source(source);
- let lib = device
- .new_library_with_source(source_content, &CompileOptions::new())
- .map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))?;
+ let lib = match source {
+ Source::Mfa => {
+ let source_data = MFA;
+ device.new_library_with_data(source_data).map_err(|e| {
+ MetalKernelError::LoadLibraryError(format!(
+ "Candle metal requires macosx > 13.0 or higher, cannot load mfa: {e}"
+ ))
+ })?
+ }
+ source => {
+ let source_content = self.get_library_source(source);
+ device
+ .new_library_with_source(source_content, &CompileOptions::new())
+ .map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))?
+ }
+ };
libraries.insert(source, lib.clone());
Ok(lib)
}
}
- pub fn load_function(
+ fn load_function(
&self,
device: &Device,
source: Source,
name: &'static str,
+ constants: Option<FunctionConstantValues>,
) -> Result<Function, MetalKernelError> {
- let mut funcs = self.funcs.write()?;
- if let Some(func) = funcs.get(name) {
- Ok(func.clone())
+ let func = self
+ .load_library(device, source)?
+ .get_function(name, constants)
+ .map_err(|e| MetalKernelError::LoadFunctionError(e.to_string()))?;
+ Ok(func)
+ }
+
+ /// Load the give pipeline
+ /// loads the library from source, then gets the function [`name`] from
+ /// that source
+ fn load_pipeline_with_constants(
+ &self,
+ device: &Device,
+ source: Source,
+ name: &'static str,
+ constants: Option<ConstantValues>,
+ ) -> Result<ComputePipelineState, MetalKernelError> {
+ let mut pipelines = self.pipelines.write()?;
+ let key = (name, constants);
+ if let Some(pipeline) = pipelines.get(&key) {
+ Ok(pipeline.clone())
} else {
- let func = self
- .load_library(device, source)?
- .get_function(name, None)
- .map_err(|e| MetalKernelError::LoadFunctionError(e.to_string()))?;
- funcs.insert(name, func.clone());
- Ok(func)
+ let (name, constants) = key;
+ let func = self.load_function(
+ device,
+ source,
+ name,
+ constants.as_ref().map(|c| c.function_constant_values()),
+ )?;
+ let pipeline = device
+ .new_compute_pipeline_state_with_function(&func)
+ .map_err(|e| MetalKernelError::FailedToCreatePipeline(e.to_string()))?;
+ pipelines.insert((name, constants), pipeline.clone());
+
+ Ok(pipeline)
}
}
+
+ /// Load the give pipeline
+ /// loads the library from source, then gets the function [`name`] from
+ /// that source (without constants)
+ pub fn load_pipeline(
+ &self,
+ device: &Device,
+ source: Source,
+ name: &'static str,
+ ) -> Result<ComputePipelineState, MetalKernelError> {
+ self.load_pipeline_with_constants(device, source, name, None)
+ }
}
#[allow(clippy::too_many_arguments)]
@@ -246,25 +327,20 @@ pub fn call_unary_contiguous(
kernel_name: unary::contiguous::Kernel,
length: usize,
input: &Buffer,
- output: &mut Buffer,
+ output: &Buffer,
) -> Result<(), MetalKernelError> {
- let func = kernels.load_function(device, Source::Unary, kernel_name.0)?;
- let pipeline_state_descriptor = ComputePipelineDescriptor::new();
- pipeline_state_descriptor.set_compute_function(Some(&func));
-
- let pipeline = device
- .new_compute_pipeline_state_with_function(
- pipeline_state_descriptor.compute_function().unwrap(),
- )
- .unwrap();
-
+ let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
let encoder = command_buffer.new_compute_command_encoder();
+ encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (length, input, output));
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
+ encoder.use_resource(input, metal::MTLResourceUsage::Read);
+ encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+ encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
@@ -279,21 +355,14 @@ pub fn call_unary_strided(
input: &Buffer,
strides: &[usize],
offset: usize,
- output: &mut Buffer,
+ output: &Buffer,
output_offset: usize,
) -> Result<(), MetalKernelError> {
- let func = kernels.load_function(device, Source::Unary, name.0)?;
- let pipeline_state_descriptor = ComputePipelineDescriptor::new();
- pipeline_state_descriptor.set_compute_function(Some(&func));
-
- let pipeline = device
- .new_compute_pipeline_state_with_function(
- pipeline_state_descriptor.compute_function().unwrap(),
- )
- .unwrap();
+ let pipeline = kernels.load_pipeline(device, Source::Unary, name.0)?;
let num_dims: usize = shape.len();
let encoder = command_buffer.new_compute_command_encoder();
+ encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
let length: usize = shape.iter().product();
@@ -312,7 +381,10 @@ pub fn call_unary_strided(
let width: usize = shape.iter().product();
let (thread_group_count, thread_group_size) = linear_split(&pipeline, width);
+ encoder.use_resource(input, metal::MTLResourceUsage::Read);
+ encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+ encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
@@ -326,26 +398,23 @@ pub fn call_binary_contiguous(
length: usize,
left: &Buffer,
right: &Buffer,
- output: &mut Buffer,
+ output: &Buffer,
) -> Result<(), MetalKernelError> {
- let func = kernels.load_function(device, Source::Binary, kernel_name.0)?;
- let pipeline_state_descriptor = ComputePipelineDescriptor::new();
- pipeline_state_descriptor.set_compute_function(Some(&func));
-
- let pipeline = device
- .new_compute_pipeline_state_with_function(
- pipeline_state_descriptor.compute_function().unwrap(),
- )
- .unwrap();
+ let pipeline = kernels.load_pipeline(device, Source::Binary, kernel_name.0)?;
let encoder = command_buffer.new_compute_command_encoder();
+ encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (length, left, right, output));
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
+ encoder.use_resource(left, metal::MTLResourceUsage::Read);
+ encoder.use_resource(right, metal::MTLResourceUsage::Read);
+ encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+ encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
@@ -363,21 +432,14 @@ pub fn call_binary_strided(
right_input: &Buffer,
right_strides: &[usize],
right_offset: usize,
- output: &mut Buffer,
+ output: &Buffer,
) -> Result<(), MetalKernelError> {
- let func = kernels.load_function(device, Source::Binary, name.0)?;
- let pipeline_state_descriptor = ComputePipelineDescriptor::new();
- pipeline_state_descriptor.set_compute_function(Some(&func));
-
- let pipeline = device
- .new_compute_pipeline_state_with_function(
- pipeline_state_descriptor.compute_function().unwrap(),
- )
- .unwrap();
+ let pipeline = kernels.load_pipeline(device, Source::Binary, name.0)?;
let num_dims: usize = shape.len();
let encoder = command_buffer.new_compute_command_encoder();
let width: usize = shape.iter().product();
+ encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
let length: usize = shape.iter().product();
@@ -398,7 +460,11 @@ pub fn call_binary_strided(
let (thread_group_count, thread_group_size) = linear_split(&pipeline, width);
+ encoder.use_resource(left_input, metal::MTLResourceUsage::Read);
+ encoder.use_resource(right_input, metal::MTLResourceUsage::Read);
+ encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+ encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
@@ -411,31 +477,68 @@ pub fn call_cast_contiguous(
kernel_name: &'static str,
length: usize,
input: &Buffer,
- output: &mut Buffer,
+ input_offset: usize,
+ output: &Buffer,
) -> Result<(), MetalKernelError> {
- let func = kernels.load_function(device, Source::Cast, kernel_name)?;
- let pipeline_state_descriptor = ComputePipelineDescriptor::new();
- pipeline_state_descriptor.set_compute_function(Some(&func));
+ let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?;
- let pipeline = device
- .new_compute_pipeline_state_with_function(
- pipeline_state_descriptor.compute_function().unwrap(),
- )
- .unwrap();
+ let encoder = command_buffer.new_compute_command_encoder();
+ encoder.wait_for_fence(&kernels.fence);
+ encoder.set_compute_pipeline_state(&pipeline);
+
+ set_params!(encoder, (length, (input, input_offset), output));
+
+ let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
+ encoder.use_resource(input, metal::MTLResourceUsage::Read);
+ encoder.use_resource(output, metal::MTLResourceUsage::Write);
+ encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+ encoder.update_fence(&kernels.fence);
+ encoder.end_encoding();
+ Ok(())
+}
+
+#[allow(clippy::too_many_arguments)]
+pub fn call_cast_strided(
+ device: &Device,
+ command_buffer: &CommandBufferRef,
+ kernels: &Kernels,
+ kernel_name: &'static str,
+ shape: &[usize],
+ input: &Buffer,
+ input_strides: &[usize],
+ input_offset: usize,
+ output: &Buffer,
+) -> Result<(), MetalKernelError> {
+ let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?;
let encoder = command_buffer.new_compute_command_encoder();
+ encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
- set_params!(encoder, (length, input, output));
+ let length: usize = shape.iter().product();
+
+ set_params!(
+ encoder,
+ (
+ length,
+ shape.len(),
+ shape,
+ input_strides,
+ (input, input_offset),
+ output
+ )
+ );
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
+ encoder.use_resource(input, metal::MTLResourceUsage::Read);
+ encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+ encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
-#[allow(clippy::too_many_arguments)]
pub fn call_reduce_contiguous(
device: &Device,
command_buffer: &CommandBufferRef,
@@ -444,24 +547,78 @@ pub fn call_reduce_contiguous(
length: usize,
out_length: usize,
input: &Buffer,
- output: &mut Buffer,
+ input_offset: usize,
+ output: &Buffer,
) -> Result<(), MetalKernelError> {
- let func = kernels.load_function(device, Source::Reduce, kernel_name)?;
- let pipeline_state_descriptor = ComputePipelineDescriptor::new();
- pipeline_state_descriptor.set_compute_function(Some(&func));
+ let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
+ let elements_to_sum = length / out_length;
- let pipeline = device
- .new_compute_pipeline_state_with_function(
- pipeline_state_descriptor.compute_function().unwrap(),
- )
- .unwrap();
+ let encoder = command_buffer.new_compute_command_encoder();
+ encoder.wait_for_fence(&kernels.fence);
+ encoder.set_compute_pipeline_state(&pipeline);
+
+ set_params!(
+ encoder,
+ (length, elements_to_sum, (input, input_offset), output)
+ );
+
+ let thread_group_count = MTLSize {
+ width: out_length as u64,
+ height: 1,
+ depth: 1,
+ };
+
+ let width = std::cmp::min(
+ pipeline.max_total_threads_per_threadgroup(),
+ (elements_to_sum as u64 + 2 - 1) / 2,
+ )
+ .next_power_of_two();
+ let thread_group_size = MTLSize {
+ width,
+ height: 1,
+ depth: 1,
+ };
+
+ encoder.use_resource(input, metal::MTLResourceUsage::Read);
+ encoder.use_resource(output, metal::MTLResourceUsage::Write);
+ encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+ encoder.update_fence(&kernels.fence);
+ encoder.end_encoding();
+ Ok(())
+}
+
+pub fn call_reduce_strided(
+ device: &Device,
+ command_buffer: &CommandBufferRef,
+ kernels: &Kernels,
+ kernel_name: &'static str,
+ shape: &[usize],
+ strides: &[usize],
+ out_length: usize,
+ input: &Buffer,
+ input_offset: usize,
+ output: &Buffer,
+) -> Result<(), MetalKernelError> {
+ let length: usize = shape.iter().product();
+ let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
let elements_to_sum = length / out_length;
let encoder = command_buffer.new_compute_command_encoder();
+ encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
- set_params!(encoder, (length, elements_to_sum, input, output));
+ set_params!(
+ encoder,
+ (
+ shape.len(),
+ shape,
+ strides,
+ elements_to_sum,
+ (input, input_offset),
+ output
+ )
+ );
let thread_group_count = MTLSize {
width: out_length as u64,
@@ -471,7 +628,7 @@ pub fn call_reduce_contiguous(
let width = std::cmp::min(
pipeline.max_total_threads_per_threadgroup(),
- (elements_to_sum as u64 + 2 - 1) / 2,
+ elements_to_sum as u64,
)
.next_power_of_two();
@@ -481,7 +638,10 @@ pub fn call_reduce_contiguous(
depth: 1,
};
+ encoder.use_resource(input, metal::MTLResourceUsage::Read);
+ encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+ encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
@@ -495,22 +655,18 @@ pub fn call_last_softmax(
length: usize,
elements_to_sum: usize,
input: &Buffer,
- output: &mut Buffer,
+ input_offset: usize,
+ output: &Buffer,
) -> Result<(), MetalKernelError> {
- let func = kernels.load_function(device, Source::Reduce, kernel_name)?;
- let pipeline_state_descriptor = ComputePipelineDescriptor::new();
- pipeline_state_descriptor.set_compute_function(Some(&func));
-
- let pipeline = device
- .new_compute_pipeline_state_with_function(
- pipeline_state_descriptor.compute_function().unwrap(),
- )
- .unwrap();
-
+ let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
let encoder = command_buffer.new_compute_command_encoder();
+ encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
- set_params!(encoder, (length, elements_to_sum, input, output));
+ set_params!(
+ encoder,
+ (length, elements_to_sum, (input, input_offset), output)
+ );
let out_length = length / elements_to_sum;
@@ -532,7 +688,10 @@ pub fn call_last_softmax(
depth: 1,
};
+ encoder.use_resource(input, metal::MTLResourceUsage::Read);
+ encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+ encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
@@ -542,34 +701,214 @@ pub fn call_affine(
device: &Device,
command_buffer: &CommandBufferRef,
kernels: &Kernels,
+ name: &'static str,
size: usize,
input: &Buffer,
- output: &mut Buffer,
+ output: &Buffer,
+ mul: f32,
+ add: f32,
+) -> Result<(), MetalKernelError> {
+ let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
+
+ let encoder = command_buffer.new_compute_command_encoder();
+ encoder.wait_for_fence(&kernels.fence);
+ encoder.set_compute_pipeline_state(&pipeline);
+
+ set_params!(encoder, (size, mul, add, input, output));
+
+ let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
+ encoder.use_resource(input, metal::MTLResourceUsage::Read);
+ encoder.use_resource(output, metal::MTLResourceUsage::Write);
+ encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+ encoder.update_fence(&kernels.fence);
+ encoder.end_encoding();
+ Ok(())
+}
+
+#[allow(clippy::too_many_arguments)]
+pub fn call_affine_strided(
+ device: &Device,
+ command_buffer: &CommandBufferRef,
+ kernels: &Kernels,
+ name: &'static str,
+ shape: &[usize],
+ input: &Buffer,
+ input_stride: &[usize],
+ input_offset: usize,
+ output: &Buffer,
mul: f32,
add: f32,
) -> Result<(), MetalKernelError> {
- let func = kernels.load_function(device, Source::Affine, "affine_float")?;
- let pipeline_state_descriptor = ComputePipelineDescriptor::new();
- pipeline_state_descriptor.set_compute_function(Some(&func));
+ let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
+ let size: usize = shape.iter().product();
+
+ let encoder = command_buffer.new_compute_command_encoder();
+ encoder.wait_for_fence(&kernels.fence);
+ encoder.set_compute_pipeline_state(&pipeline);
- let pipeline = device
- .new_compute_pipeline_state_with_function(
- pipeline_state_descriptor.compute_function().unwrap(),
+ set_params!(
+ encoder,
+ (
+ size,
+ shape.len(),
+ shape,
+ input_stride,
+ mul,
+ add,
+ (input, input_offset),
+ output
)
- .unwrap();
+ );
+
+ let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
+ encoder.use_resource(input, metal::MTLResourceUsage::Read);
+ encoder.use_resource(output, metal::MTLResourceUsage::Write);
+ encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+ encoder.update_fence(&kernels.fence);
+ encoder.end_encoding();
+ Ok(())
+}
+
+#[allow(clippy::too_many_arguments)]
+pub fn call_powf(
+ device: &Device,
+ command_buffer: &CommandBufferRef,
+ kernels: &Kernels,
+ name: &'static str,
+ size: usize,
+ input: &Buffer,
+ output: &Buffer,
+ mul: f32,
+) -> Result<(), MetalKernelError> {
+ let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
let encoder = command_buffer.new_compute_command_encoder();
+ encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
- set_params!(encoder, (size, mul, add, input, output));
+ set_params!(encoder, (size, mul, input, output));
+
+ let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
+ encoder.use_resource(input, metal::MTLResourceUsage::Read);
+ encoder.use_resource(output, metal::MTLResourceUsage::Write);
+ encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+ encoder.update_fence(&kernels.fence);
+ encoder.end_encoding();
+ Ok(())
+}
+
+#[allow(clippy::too_many_arguments)]
+pub fn call_powf_strided(
+ device: &Device,
+ command_buffer: &CommandBufferRef,
+ kernels: &Kernels,
+ name: &'static str,
+ shape: &[usize],
+ input: &Buffer,
+ input_stride: &[usize],
+ input_offset: usize,
+ output: &Buffer,
+ mul: f32,
+) -> Result<(), MetalKernelError> {
+ let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
+ let size: usize = shape.iter().product();
+
+ let encoder = command_buffer.new_compute_command_encoder();
+ encoder.wait_for_fence(&kernels.fence);
+ encoder.set_compute_pipeline_state(&pipeline);
+
+ set_params!(
+ encoder,
+ (
+ size,
+ shape.len(),
+ shape,
+ input_stride,
+ mul,
+ (input, input_offset),
+ output
+ )
+ );
+
+ let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
+ encoder.use_resource(input, metal::MTLResourceUsage::Read);
+ encoder.use_resource(output, metal::MTLResourceUsage::Write);
+ encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+ encoder.update_fence(&kernels.fence);
+ encoder.end_encoding();
+ Ok(())
+}
+
+#[allow(clippy::too_many_arguments)]
+pub fn call_elu(
+ device: &Device,
+ command_buffer: &CommandBufferRef,
+ kernels: &Kernels,
+ name: &'static str,
+ size: usize,
+ input: &Buffer,
+ output: &Buffer,
+ mul: f32,
+) -> Result<(), MetalKernelError> {
+ let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
+
+ let encoder = command_buffer.new_compute_command_encoder();
+ encoder.wait_for_fence(&kernels.fence);
+ encoder.set_compute_pipeline_state(&pipeline);
+
+ set_params!(encoder, (size, mul, input, output));
let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
+ encoder.use_resource(input, metal::MTLResourceUsage::Read);
+ encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+ encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
#[allow(clippy::too_many_arguments)]
+pub fn call_elu_strided(
+ device: &Device,
+ command_buffer: &CommandBufferRef,
+ kernels: &Kernels,
+ name: &'static str,
+ shape: &[usize],
+ input: &Buffer,
+ input_stride: &[usize],
+ input_offset: usize,
+ output: &Buffer,
+ mul: f32,
+) -> Result<(), MetalKernelError> {
+ let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
+ let size: usize = shape.iter().product();
+
+ let encoder = command_buffer.new_compute_command_encoder();
+ encoder.wait_for_fence(&kernels.fence);
+ encoder.set_compute_pipeline_state(&pipeline);
+
+ set_params!(
+ encoder,
+ (
+ size,
+ shape.len(),
+ shape,
+ input_stride,
+ mul,
+ (input, input_offset),
+ output
+ )
+ );
+
+ let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
+ encoder.use_resource(input, metal::MTLResourceUsage::Read);
+ encoder.use_resource(output, metal::MTLResourceUsage::Write);
+ encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+ encoder.update_fence(&kernels.fence);
+ encoder.end_encoding();
+ Ok(())
+}
+
pub fn call_where_cond_strided(
device: &Device,
command_buffer: &CommandBufferRef,
@@ -582,19 +921,12 @@ pub fn call_where_cond_strided(
(left_stride, left_offset): (&[usize], usize),
right: &Buffer,
(right_stride, right_offset): (&[usize], usize),
- output: &mut Buffer,
+ output: &Buffer,
) -> Result<(), MetalKernelError> {
- let func = kernels.load_function(device, Source::Ternary, name)?;
- let pipeline_state_descriptor = ComputePipelineDescriptor::new();
- pipeline_state_descriptor.set_compute_function(Some(&func));
-
- let pipeline = device
- .new_compute_pipeline_state_with_function(
- pipeline_state_descriptor.compute_function().unwrap(),
- )
- .unwrap();
+ let pipeline = kernels.load_pipeline(device, Source::Ternary, name)?;
let encoder = command_buffer.new_compute_command_encoder();
+ encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
let size: usize = shape.iter().product();
@@ -618,7 +950,12 @@ pub fn call_where_cond_strided(
let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
+ encoder.use_resource(cond, metal::MTLResourceUsage::Read);
+ encoder.use_resource(left, metal::MTLResourceUsage::Read);
+ encoder.use_resource(right, metal::MTLResourceUsage::Read);
+ encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+ encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
@@ -634,20 +971,18 @@ pub fn call_index_select(
dim: usize,
input: &Buffer,
ids: &Buffer,
- output: &mut Buffer,
+ output: &Buffer,
) -> Result<(), MetalKernelError> {
let left_size: usize = shape[..dim].iter().product();
let right_size: usize = shape[dim + 1..].iter().product();
let src_dim_size = shape[dim];
let dst_el = ids_size * left_size * right_size;
- let func = kernels.load_function(device, Source::Indexing, name)?;
- let pipeline = device
- .new_compute_pipeline_state_with_function(&func)
- .unwrap();
+ let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?;
let encoder = command_buffer.new_compute_command_encoder();
+ encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
@@ -666,10 +1001,426 @@ pub fn call_index_select(
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
+ encoder.use_resource(input, metal::MTLResourceUsage::Read);
+ encoder.use_resource(ids, metal::MTLResourceUsage::Read);
+ encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+ encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
+#[allow(clippy::too_many_arguments)]
+pub fn call_gather(
+ device: &Device,
+ command_buffer: &CommandBufferRef,
+ kernels: &Kernels,
+ name: &'static str,
+ shape: &[usize],
+ ids_size: usize,
+ dim: usize,
+ input: &Buffer,
+ input_offset: usize,
+ ids: &Buffer,
+ ids_offset: usize,
+ output: &Buffer,
+) -> Result<(), MetalKernelError> {
+ let left_size: usize = shape[..dim].iter().product();
+ let right_size: usize = shape[dim + 1..].iter().product();
+ let src_dim_size = shape[dim];
+ let dst_el = ids_size * left_size * right_size;
+
+ let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?;
+
+ let encoder = command_buffer.new_compute_command_encoder();
+
+ encoder.wait_for_fence(&kernels.fence);
+ encoder.set_compute_pipeline_state(&pipeline);
+
+ set_params!(
+ encoder,
+ (
+ dst_el,
+ left_size,
+ src_dim_size,
+ right_size,
+ ids_size,
+ (input, input_offset),
+ (ids, ids_offset),
+ output
+ )
+ );
+
+ let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
+
+ encoder.use_resource(input, metal::MTLResourceUsage::Read);
+ encoder.use_resource(ids, metal::MTLResourceUsage::Read);
+ encoder.use_resource(output, metal::MTLResourceUsage::Write);
+ encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+ encoder.update_fence(&kernels.fence);
+ encoder.end_encoding();
+ Ok(())
+}
+
+pub fn call_scatter_add(
+ device: &Device,
+ command_buffer: &CommandBufferRef,
+ kernels: &Kernels,
+ name: &'static str,
+ src_shape: &[usize],
+ dst_shape: &[usize],
+ dim: usize,
+ input: &Buffer,
+ input_offset: usize,
+ ids: &Buffer,
+ ids_offset: usize,
+ output: &Buffer,
+) -> Result<(), MetalKernelError> {
+ let left_size: usize = src_shape[..dim].iter().product();
+ let right_size: usize = src_shape[dim + 1..].iter().product();
+ let src_dim_size = src_shape[dim];
+ let dst_el = left_size * right_size;
+ let dst_dim_size = dst_shape[dim];
+
+ let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?;
+
+ let encoder = command_buffer.new_compute_command_encoder();
+
+ encoder.wait_for_fence(&kernels.fence);
+ encoder.set_compute_pipeline_state(&pipeline);
+
+ set_params!(
+ encoder,
+ (
+ dst_el,
+ left_size,
+ src_dim_size,
+ right_size,
+ dst_dim_size,
+ (input, input_offset),
+ (ids, ids_offset),
+ output
+ )
+ );
+
+ let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
+
+ encoder.use_resource(input, metal::MTLResourceUsage::Read);
+ encoder.use_resource(ids, metal::MTLResourceUsage::Read);
+ encoder.use_resource(output, metal::MTLResourceUsage::Write);
+ encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+ encoder.update_fence(&kernels.fence);
+ encoder.end_encoding();
+ Ok(())
+}
+
+pub fn call_index_add(
+ device: &Device,
+ command_buffer: &CommandBufferRef,
+ kernels: &Kernels,
+ name: &'static str,
+ src_shape: &[usize],
+ dst_shape: &[usize],
+ ids_shape: &[usize],
+ dim: usize,
+ input: &Buffer,
+ input_offset: usize,
+ ids: &Buffer,
+ ids_offset: usize,
+ output: &Buffer,
+) -> Result<(), MetalKernelError> {
+ let left_size: usize = src_shape[..dim].iter().product();
+ let right_size: usize = src_shape[dim + 1..].iter().product();
+ let src_dim_size = src_shape[dim];
+ let dst_el = left_size * right_size;
+ let dst_dim_size = dst_shape[dim];
+ let ids_dim_size = ids_shape[0];
+
+ let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?;
+ let encoder = command_buffer.new_compute_command_encoder();
+
+ encoder.wait_for_fence(&kernels.fence);
+ encoder.set_compute_pipeline_state(&pipeline);
+
+ set_params!(
+ encoder,
+ (
+ dst_el,
+ left_size,
+ src_dim_size,
+ right_size,
+ dst_dim_size,
+ ids_dim_size,
+ (input, input_offset),
+ (ids, ids_offset),
+ output
+ )
+ );
+
+ let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
+
+ encoder.use_resource(input, metal::MTLResourceUsage::Read);
+ encoder.use_resource(ids, metal::MTLResourceUsage::Read);
+ encoder.use_resource(output, metal::MTLResourceUsage::Write);
+ encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+ encoder.update_fence(&kernels.fence);
+ encoder.end_encoding();
+ Ok(())
+}
+
+#[derive(Debug, PartialEq)]
+pub enum Value {
+ USize(usize),
+ Bool(bool),
+ F32(f32),
+ U16(u16),
+}
+
+impl std::hash::Hash for Value {
+ fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
+ match self {
+ Value::F32(v) => v.to_bits().hash(state),
+ Value::USize(v) => v.hash(state),
+ Value::U16(v) => v.hash(state),
+ Value::Bool(v) => v.hash(state),
+ }
+ }
+}
+
+impl Value {
+ fn data_type(&self) -> MTLDataType {
+ match self {
+ Value::USize(_) => MTLDataType::UInt,
+ Value::F32(_) => MTLDataType::Float,
+ Value::U16(_) => MTLDataType::UShort,
+ Value::Bool(_) => MTLDataType::Bool,
+ }
+ }
+}
+
+/// Not true, good enough for our purposes.
+impl Eq for Value {}
+
+#[derive(Debug, Eq, PartialEq, Hash)]
+struct ConstantValues(Vec<(usize, Value)>);
+
+impl ConstantValues {
+ pub fn new(values: Vec<(usize, Value)>) -> Self {
+ Self(values)
+ }
+
+ fn function_constant_values(&self) -> FunctionConstantValues {
+ let f = FunctionConstantValues::new();
+ for (index, value) in &self.0 {
+ let ty = value.data_type();
+ match value {
+ Value::USize(v) => {
+ f.set_constant_value_at_index(
+ v as *const usize as *const c_void,
+ ty,
+ *index as u64,
+ );
+ }
+ Value::F32(v) => {
+ f.set_constant_value_at_index(
+ v as *const f32 as *const c_void,
+ ty,
+ *index as u64,
+ );
+ }
+ Value::U16(v) => {
+ f.set_constant_value_at_index(
+ v as *const u16 as *const c_void,
+ ty,
+ *index as u64,
+ );
+ }
+ Value::Bool(v) => {
+ f.set_constant_value_at_index(
+ v as *const bool as *const c_void,
+ ty,
+ *index as u64,
+ );
+ }
+ }
+ }
+ f
+ }
+}
+
+#[allow(clippy::too_many_arguments)]
+pub fn call_gemm(
+ device: &Device,
+ command_buffer: &CommandBufferRef,
+ kernels: &Kernels,
+ name: &'static str,
+ (b, m, n, k): (usize, usize, usize, usize),
+ lhs_stride: &[usize],
+ lhs_offset: usize,
+ lhs_buffer: &Buffer,
+ rhs_stride: &[usize],
+ rhs_offset: usize,
+ rhs_buffer: &Buffer,
+ output: &Buffer,
+) -> Result<(), MetalKernelError> {
+ assert!(rhs_stride.len() >= 2);
+ assert!(lhs_stride.len() >= 2);
+ let rhs_m1 = rhs_stride[rhs_stride.len() - 1];
+ let rhs_m2 = rhs_stride[rhs_stride.len() - 2];
+ let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
+ let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
+ let a_trans = if lhs_m1 == 1 && lhs_m2 == k {
+ false
+ } else if lhs_m1 == m && lhs_m2 == 1 {
+ true
+ } else {
+ return Err(MetalKernelError::MatMulNonContiguous {
+ lhs_stride: lhs_stride.to_vec(),
+ rhs_stride: rhs_stride.to_vec(),
+ mnk: (m, n, k),
+ })?;
+ };
+ let b_trans = if rhs_m1 == 1 && rhs_m2 == n {
+ false
+ } else if rhs_m1 == k && rhs_m2 == 1 {
+ true
+ } else {
+ return Err(MetalKernelError::MatMulNonContiguous {
+ lhs_stride: lhs_stride.to_vec(),
+ rhs_stride: rhs_stride.to_vec(),
+ mnk: (m, n, k),
+ })?;
+ };
+ let d_trans = false;
+ let alpha = 1.0f32;
+ let beta = 0.0f32;
+ let batched = b > 1;
+ let fused_activation = false;
+ let fused_bias = false;
+ let (m_simd, n_simd, k_simd, m_splits, n_splits) = if m == 1 {
+ let m_simd = 16;
+ let n_simd = 8;
+ let k_simd = 64;
+ let m_splits = 1;
+ let n_splits = 1;
+ (m_simd, n_simd, k_simd, m_splits, n_splits)
+ } else {
+ let m_simd = 40;
+ let n_simd = 40;
+ let k_simd = 8;
+ let m_splits = 1;
+ let n_splits = 1;
+ (m_simd, n_simd, k_simd, m_splits, n_splits)
+ };
+ let constants = Some(ConstantValues::new(vec![
+ (0, Value::USize(m)),
+ (1, Value::USize(n)),
+ (2, Value::USize(k)),
+ (10, Value::Bool(a_trans)),
+ (11, Value::Bool(b_trans)),
+ (13, Value::Bool(d_trans)),
+ (20, Value::F32(alpha)),
+ (21, Value::F32(beta)),
+ (100, Value::Bool(batched)),
+ (101, Value::Bool(fused_activation)),
+ // Garbage
+ (102, Value::Bool(false)),
+ (103, Value::Bool(false)),
+ (113, Value::Bool(false)),
+ (50_000, Value::Bool(false)),
+ // End garbage
+ (200, Value::U16(m_simd)),
+ (201, Value::U16(n_simd)),
+ (202, Value::U16(k_simd)),
+ (210, Value::U16(m_splits)),
+ (211, Value::U16(n_splits)),
+ (50_001, Value::Bool(fused_bias)),
+ ]));
+ let pipeline = kernels.load_pipeline_with_constants(device, Source::Mfa, name, constants)?;
+ let m_group = m_simd * m_splits;
+ let n_group = n_simd * n_splits;
+
+ let a_block_length = m_group * k_simd;
+ let b_block_length = k_simd * n_group;
+
+ let mut block_elements = a_block_length + b_block_length;
+ if (m % 8 != 0) && (n % 8 != 0) {
+ let c_block_length = m_group * n_group;
+ block_elements = std::cmp::max(c_block_length, block_elements)
+ }
+ if fused_bias {
+ if d_trans {
+ block_elements = std::cmp::max(block_elements, m_group);
+ } else {
+ block_elements = std::cmp::max(block_elements, n_group);
+ }
+ }
+ let bytes = match name {
+ "sgemm" => 4,
+ "hgemm" => 2,
+ other => {
+ return Err(MetalKernelError::LoadLibraryError(format!(
+ "{other} is not a valid kernel for gemm"
+ )));
+ }
+ };
+ let block_bytes = block_elements * bytes;
+
+ let encoder = command_buffer.new_compute_command_encoder();
+ encoder.wait_for_fence(&kernels.fence);
+ encoder.set_compute_pipeline_state(&pipeline);
+ encoder.set_threadgroup_memory_length(0, block_bytes.into());
+ encoder.set_buffer(0, Some(lhs_buffer), lhs_offset as NSUInteger);
+ encoder.set_buffer(1, Some(rhs_buffer), rhs_offset as NSUInteger);
+ encoder.set_buffer(2, Some(output), 0);
+ // TODO Tensor D
+
+ let grid_z = b;
+ if batched {
+ let byte_stride_a: usize = lhs_stride[lhs_stride.len() - 3] * bytes as usize;
+ let byte_stride_b: usize = rhs_stride[rhs_stride.len() - 3] * bytes as usize;
+ let byte_stride_c = m * n * bytes as usize;
+ // TODO byte_stride_d
+ let byte_stride_d = 0;
+
+ let mut buffer: Vec<u64> = Vec::with_capacity(b * 4);
+ for i in 0..b {
+ buffer.push((i * byte_stride_a) as u64);
+ buffer.push((i * byte_stride_b) as u64);
+ buffer.push((i * byte_stride_c) as u64);
+ buffer.push((i * byte_stride_d) as u64);
+ }
+ encoder.set_bytes(
+ 10,
+ (buffer.len() * core::mem::size_of::<u64>()) as NSUInteger,
+ buffer.as_ptr() as *const NSUInteger as *const c_void,
+ );
+ }
+
+ let grid_size = MTLSize {
+ width: divide(n, n_group.into()),
+ height: divide(m, m_group.into()),
+ depth: grid_z as NSUInteger,
+ };
+ let group_size = MTLSize {
+ width: 32 * (m_splits as u64) * (n_splits as u64),
+ height: 1,
+ depth: 1,
+ };
+ // println!("grid size {grid_size:?} group size {group_size:?}");
+ encoder.use_resource(lhs_buffer, metal::MTLResourceUsage::Read);
+ encoder.use_resource(rhs_buffer, metal::MTLResourceUsage::Read);
+ encoder.use_resource(output, metal::MTLResourceUsage::Write);
+ encoder.dispatch_thread_groups(grid_size, group_size);
+ encoder.update_fence(&kernels.fence);
+ encoder.end_encoding();
+
+ Ok(())
+}
+
+fn divide(m: usize, b: usize) -> NSUInteger {
+ ((m + b - 1) / b) as NSUInteger
+}
+
#[cfg(test)]
mod tests;