diff options
Diffstat (limited to 'candle-metal-kernels/src/lib.rs')
-rw-r--r-- | candle-metal-kernels/src/lib.rs | 1045 |
1 file changed, 898 insertions, 147 deletions
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index 5a6bd41b..0418c96c 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -1,6 +1,6 @@ use metal::{ - Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineDescriptor, - ComputePipelineState, Device, Function, Library, MTLSize, + Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineState, + Device, Function, FunctionConstantValues, Library, MTLDataType, MTLSize, NSUInteger, }; use std::collections::HashMap; use std::ffi::c_void; @@ -13,7 +13,12 @@ const BINARY: &str = include_str!("binary.metal"); const TERNARY: &str = include_str!("ternary.metal"); const CAST: &str = include_str!("cast.metal"); const REDUCE: &str = include_str!("reduce.metal"); +const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib"); +/// Most kernels apply similarly across the tensors +/// This creates a strategy that uses the maximum amount of threads per threadgroup (capped at the +/// actual total buffer length). +/// Then kernels can just do their op on their single point in the buffer. fn linear_split(pipeline: &ComputePipelineState, length: usize) -> (MTLSize, MTLSize) { let size = length as u64; let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), size); @@ -35,6 +40,10 @@ fn linear_split(pipeline: &ComputePipelineState, length: usize) -> (MTLSize, MTL fn set_param<P: EncoderParam>(encoder: &ComputeCommandEncoderRef, position: u64, data: P) { <P as EncoderParam>::set_param(encoder, position, data) } + +/// Helper functions to create the various objects on the compute command encoder +/// on a single line. +/// Prevents getting wrong some arguments number and mixing length and size in bytes. 
trait EncoderParam { fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self); } @@ -59,8 +68,8 @@ impl<T> EncoderParam for &[T] { fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) { encoder.set_bytes( position, - (core::mem::size_of::<T>() * data.len()) as u64, - data.as_ptr() as *const T as *const c_void, + core::mem::size_of_val(data) as u64, + data.as_ptr() as *const c_void, ); } } @@ -105,54 +114,59 @@ pub enum Source { Ternary, Cast, Reduce, + Mfa, } macro_rules! ops{ ($($name:ident),+) => { pub mod contiguous { - #[derive(Clone, Copy)] - pub struct Kernel(pub(crate) &'static str); - impl std::fmt::Display for Kernel { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.0) - } - } + pub struct Kernel(pub &'static str); $( pub mod $name { use super::Kernel; - pub const FLOAT: Kernel = Kernel(concat!(stringify!($name), "_float")); - pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_half")); - pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bfloat")); + pub const FLOAT: Kernel = Kernel(concat!(stringify!($name), "_f32")); + pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_f16")); + pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bf16")); } )+ + pub mod copy { + use super::Kernel; + pub const FLOAT: Kernel = Kernel("copy_f32"); + pub const HALF: Kernel = Kernel("copy_f16"); + pub const BFLOAT: Kernel = Kernel("copy_bf16"); + pub const U32: Kernel = Kernel("copy_u32"); + pub const U8: Kernel = Kernel("copy_u8"); + } } pub mod strided { - #[derive(Clone, Copy)] - pub struct Kernel(pub(crate) &'static str); - impl std::fmt::Display for Kernel { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.0) - } - } + pub struct Kernel(pub &'static str); $( pub mod $name { use super::Kernel; - pub const FLOAT: Kernel = Kernel(concat!(stringify!($name), "_float_strided")); - pub const HALF: Kernel 
= Kernel(concat!(stringify!($name), "_half_strided")); - pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bfloat_strided")); + pub const FLOAT: Kernel = Kernel(concat!(stringify!($name), "_f32_strided")); + pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_f16_strided")); + pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bf16_strided")); } )+ + pub mod copy { + use super::Kernel; + pub const FLOAT: Kernel = Kernel("copy_f32_strided"); + pub const HALF: Kernel = Kernel("copy_f16_strided"); + pub const BFLOAT: Kernel = Kernel("copy_bf16_strided"); + pub const U32: Kernel = Kernel("copy_u32_strided"); + pub const U8: Kernel = Kernel("copy_u8_strided"); + } } }; } pub mod unary { - ops!(cos, sin, exp, sqr, sqrt, neg, copy, log); + ops!(cos, sin, exp, sqr, sqrt, neg, log, gelu, ceil, floor, round, erf, gelu_erf, tanh); } pub mod binary { - ops!(add, sub, mul, div); + ops!(add, sub, mul, div, min, max, eq, ne, le, lt, ge, gt); } #[derive(thiserror::Error, Debug)] @@ -161,8 +175,18 @@ pub enum MetalKernelError { LockError(String), #[error("Error while loading library: {0}")] LoadLibraryError(String), - #[error("Error while loading function: {0}")] + #[error("Error while loading function: {0:?}")] LoadFunctionError(String), + #[error("Failed to create compute function")] + FailedToCreateComputeFunction, + #[error("Failed to create pipeline")] + FailedToCreatePipeline(String), + #[error("Invalid matmul arguments {lhs_stride:?} {rhs_stride:?} {mnk:?}")] + MatMulNonContiguous { + lhs_stride: Vec<usize>, + rhs_stride: Vec<usize>, + mnk: (usize, usize, usize), + }, } impl<T> From<std::sync::PoisonError<T>> for MetalKernelError { @@ -171,21 +195,25 @@ impl<T> From<std::sync::PoisonError<T>> for MetalKernelError { } } -type KernelMap<T> = HashMap<&'static str, T>; type Libraries = HashMap<Source, Library>; -type Functions = KernelMap<Function>; +type Pipelines = HashMap<(&'static str, Option<ConstantValues>), ComputePipelineState>; 
-#[derive(Debug, Default)] +#[derive(Debug)] pub struct Kernels { libraries: RwLock<Libraries>, - funcs: RwLock<Functions>, + pipelines: RwLock<Pipelines>, + fence: metal::Fence, } impl Kernels { - pub fn new() -> Self { + pub fn new(fence: metal::Fence) -> Self { let libraries = RwLock::new(Libraries::new()); - let funcs = RwLock::new(Functions::new()); - Self { libraries, funcs } + let pipelines = RwLock::new(Pipelines::new()); + Self { + libraries, + pipelines, + fence, + } } fn get_library_source(&self, source: Source) -> &'static str { @@ -197,9 +225,12 @@ impl Kernels { Source::Indexing => INDEXING, Source::Cast => CAST, Source::Reduce => REDUCE, + Source::Mfa => panic!("Invalid lib"), } } + /// Load the give library from its [`source`]. + /// If this has been previously loaded it will just fetch it from cache. pub fn load_library( &self, device: &Device, @@ -209,33 +240,83 @@ impl Kernels { if let Some(lib) = libraries.get(&source) { Ok(lib.clone()) } else { - let source_content = self.get_library_source(source); - let lib = device - .new_library_with_source(source_content, &CompileOptions::new()) - .map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))?; + let lib = match source { + Source::Mfa => { + let source_data = MFA; + device.new_library_with_data(source_data).map_err(|e| { + MetalKernelError::LoadLibraryError(format!( + "Candle metal requires macosx > 13.0 or higher, cannot load mfa: {e}" + )) + })? + } + source => { + let source_content = self.get_library_source(source); + device + .new_library_with_source(source_content, &CompileOptions::new()) + .map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))? 
+ } + }; libraries.insert(source, lib.clone()); Ok(lib) } } - pub fn load_function( + fn load_function( &self, device: &Device, source: Source, name: &'static str, + constants: Option<FunctionConstantValues>, ) -> Result<Function, MetalKernelError> { - let mut funcs = self.funcs.write()?; - if let Some(func) = funcs.get(name) { - Ok(func.clone()) + let func = self + .load_library(device, source)? + .get_function(name, constants) + .map_err(|e| MetalKernelError::LoadFunctionError(e.to_string()))?; + Ok(func) + } + + /// Load the give pipeline + /// loads the library from source, then gets the function [`name`] from + /// that source + fn load_pipeline_with_constants( + &self, + device: &Device, + source: Source, + name: &'static str, + constants: Option<ConstantValues>, + ) -> Result<ComputePipelineState, MetalKernelError> { + let mut pipelines = self.pipelines.write()?; + let key = (name, constants); + if let Some(pipeline) = pipelines.get(&key) { + Ok(pipeline.clone()) } else { - let func = self - .load_library(device, source)? 
- .get_function(name, None) - .map_err(|e| MetalKernelError::LoadFunctionError(e.to_string()))?; - funcs.insert(name, func.clone()); - Ok(func) + let (name, constants) = key; + let func = self.load_function( + device, + source, + name, + constants.as_ref().map(|c| c.function_constant_values()), + )?; + let pipeline = device + .new_compute_pipeline_state_with_function(&func) + .map_err(|e| MetalKernelError::FailedToCreatePipeline(e.to_string()))?; + pipelines.insert((name, constants), pipeline.clone()); + + Ok(pipeline) } } + + /// Load the give pipeline + /// loads the library from source, then gets the function [`name`] from + /// that source (without constants) + pub fn load_pipeline( + &self, + device: &Device, + source: Source, + name: &'static str, + ) -> Result<ComputePipelineState, MetalKernelError> { + self.load_pipeline_with_constants(device, source, name, None) + } } #[allow(clippy::too_many_arguments)] @@ -246,25 +327,20 @@ pub fn call_unary_contiguous( kernel_name: unary::contiguous::Kernel, length: usize, input: &Buffer, - output: &mut Buffer, + output: &Buffer, ) -> Result<(), MetalKernelError> { - let func = kernels.load_function(device, Source::Unary, kernel_name.0)?; - let pipeline_state_descriptor = ComputePipelineDescriptor::new(); - pipeline_state_descriptor.set_compute_function(Some(&func)); - - let pipeline = device - .new_compute_pipeline_state_with_function( - pipeline_state_descriptor.compute_function().unwrap(), - ) - .unwrap(); - + let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?; let encoder = command_buffer.new_compute_command_encoder(); + encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); set_params!(encoder, (length, input, output)); let (thread_group_count, thread_group_size) = linear_split(&pipeline, length); + encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); 
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -279,21 +355,14 @@ pub fn call_unary_strided( input: &Buffer, strides: &[usize], offset: usize, - output: &mut Buffer, + output: &Buffer, output_offset: usize, ) -> Result<(), MetalKernelError> { - let func = kernels.load_function(device, Source::Unary, name.0)?; - let pipeline_state_descriptor = ComputePipelineDescriptor::new(); - pipeline_state_descriptor.set_compute_function(Some(&func)); - - let pipeline = device - .new_compute_pipeline_state_with_function( - pipeline_state_descriptor.compute_function().unwrap(), - ) - .unwrap(); + let pipeline = kernels.load_pipeline(device, Source::Unary, name.0)?; let num_dims: usize = shape.len(); let encoder = command_buffer.new_compute_command_encoder(); + encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); let length: usize = shape.iter().product(); @@ -312,7 +381,10 @@ pub fn call_unary_strided( let width: usize = shape.iter().product(); let (thread_group_count, thread_group_size) = linear_split(&pipeline, width); + encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -326,26 +398,23 @@ pub fn call_binary_contiguous( length: usize, left: &Buffer, right: &Buffer, - output: &mut Buffer, + output: &Buffer, ) -> Result<(), MetalKernelError> { - let func = kernels.load_function(device, Source::Binary, kernel_name.0)?; - let pipeline_state_descriptor = ComputePipelineDescriptor::new(); - pipeline_state_descriptor.set_compute_function(Some(&func)); - - let pipeline = device - .new_compute_pipeline_state_with_function( - pipeline_state_descriptor.compute_function().unwrap(), - ) - .unwrap(); + let pipeline = 
kernels.load_pipeline(device, Source::Binary, kernel_name.0)?; let encoder = command_buffer.new_compute_command_encoder(); + encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); set_params!(encoder, (length, left, right, output)); let (thread_group_count, thread_group_size) = linear_split(&pipeline, length); + encoder.use_resource(left, metal::MTLResourceUsage::Read); + encoder.use_resource(right, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -363,21 +432,14 @@ pub fn call_binary_strided( right_input: &Buffer, right_strides: &[usize], right_offset: usize, - output: &mut Buffer, + output: &Buffer, ) -> Result<(), MetalKernelError> { - let func = kernels.load_function(device, Source::Binary, name.0)?; - let pipeline_state_descriptor = ComputePipelineDescriptor::new(); - pipeline_state_descriptor.set_compute_function(Some(&func)); - - let pipeline = device - .new_compute_pipeline_state_with_function( - pipeline_state_descriptor.compute_function().unwrap(), - ) - .unwrap(); + let pipeline = kernels.load_pipeline(device, Source::Binary, name.0)?; let num_dims: usize = shape.len(); let encoder = command_buffer.new_compute_command_encoder(); let width: usize = shape.iter().product(); + encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); let length: usize = shape.iter().product(); @@ -398,7 +460,11 @@ pub fn call_binary_strided( let (thread_group_count, thread_group_size) = linear_split(&pipeline, width); + encoder.use_resource(left_input, metal::MTLResourceUsage::Read); + encoder.use_resource(right_input, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.update_fence(&kernels.fence); 
encoder.end_encoding(); Ok(()) } @@ -411,31 +477,68 @@ pub fn call_cast_contiguous( kernel_name: &'static str, length: usize, input: &Buffer, - output: &mut Buffer, + input_offset: usize, + output: &Buffer, ) -> Result<(), MetalKernelError> { - let func = kernels.load_function(device, Source::Cast, kernel_name)?; - let pipeline_state_descriptor = ComputePipelineDescriptor::new(); - pipeline_state_descriptor.set_compute_function(Some(&func)); + let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?; - let pipeline = device - .new_compute_pipeline_state_with_function( - pipeline_state_descriptor.compute_function().unwrap(), - ) - .unwrap(); + let encoder = command_buffer.new_compute_command_encoder(); + encoder.wait_for_fence(&kernels.fence); + encoder.set_compute_pipeline_state(&pipeline); + + set_params!(encoder, (length, (input, input_offset), output)); + + let (thread_group_count, thread_group_size) = linear_split(&pipeline, length); + encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.update_fence(&kernels.fence); + encoder.end_encoding(); + Ok(()) +} + +#[allow(clippy::too_many_arguments)] +pub fn call_cast_strided( + device: &Device, + command_buffer: &CommandBufferRef, + kernels: &Kernels, + kernel_name: &'static str, + shape: &[usize], + input: &Buffer, + input_strides: &[usize], + input_offset: usize, + output: &Buffer, +) -> Result<(), MetalKernelError> { + let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?; let encoder = command_buffer.new_compute_command_encoder(); + encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); - set_params!(encoder, (length, input, output)); + let length: usize = shape.iter().product(); + + set_params!( + encoder, + ( + length, + shape.len(), + shape, + input_strides, + (input, input_offset), + output + 
) + ); let (thread_group_count, thread_group_size) = linear_split(&pipeline, length); + encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } -#[allow(clippy::too_many_arguments)] pub fn call_reduce_contiguous( device: &Device, command_buffer: &CommandBufferRef, @@ -444,24 +547,78 @@ pub fn call_reduce_contiguous( length: usize, out_length: usize, input: &Buffer, - output: &mut Buffer, + input_offset: usize, + output: &Buffer, ) -> Result<(), MetalKernelError> { - let func = kernels.load_function(device, Source::Reduce, kernel_name)?; - let pipeline_state_descriptor = ComputePipelineDescriptor::new(); - pipeline_state_descriptor.set_compute_function(Some(&func)); + let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?; + let elements_to_sum = length / out_length; - let pipeline = device - .new_compute_pipeline_state_with_function( - pipeline_state_descriptor.compute_function().unwrap(), - ) - .unwrap(); + let encoder = command_buffer.new_compute_command_encoder(); + encoder.wait_for_fence(&kernels.fence); + encoder.set_compute_pipeline_state(&pipeline); + + set_params!( + encoder, + (length, elements_to_sum, (input, input_offset), output) + ); + + let thread_group_count = MTLSize { + width: out_length as u64, + height: 1, + depth: 1, + }; + + let width = std::cmp::min( + pipeline.max_total_threads_per_threadgroup(), + (elements_to_sum as u64 + 2 - 1) / 2, + ) + .next_power_of_two(); + let thread_group_size = MTLSize { + width, + height: 1, + depth: 1, + }; + + encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.update_fence(&kernels.fence); + encoder.end_encoding(); + Ok(()) +} + 
+pub fn call_reduce_strided( + device: &Device, + command_buffer: &CommandBufferRef, + kernels: &Kernels, + kernel_name: &'static str, + shape: &[usize], + strides: &[usize], + out_length: usize, + input: &Buffer, + input_offset: usize, + output: &Buffer, +) -> Result<(), MetalKernelError> { + let length: usize = shape.iter().product(); + let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?; let elements_to_sum = length / out_length; let encoder = command_buffer.new_compute_command_encoder(); + encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); - set_params!(encoder, (length, elements_to_sum, input, output)); + set_params!( + encoder, + ( + shape.len(), + shape, + strides, + elements_to_sum, + (input, input_offset), + output + ) + ); let thread_group_count = MTLSize { width: out_length as u64, @@ -471,7 +628,7 @@ pub fn call_reduce_contiguous( let width = std::cmp::min( pipeline.max_total_threads_per_threadgroup(), - (elements_to_sum as u64 + 2 - 1) / 2, + elements_to_sum as u64, ) .next_power_of_two(); @@ -481,7 +638,10 @@ pub fn call_reduce_contiguous( depth: 1, }; + encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -495,22 +655,18 @@ pub fn call_last_softmax( length: usize, elements_to_sum: usize, input: &Buffer, - output: &mut Buffer, + input_offset: usize, + output: &Buffer, ) -> Result<(), MetalKernelError> { - let func = kernels.load_function(device, Source::Reduce, kernel_name)?; - let pipeline_state_descriptor = ComputePipelineDescriptor::new(); - pipeline_state_descriptor.set_compute_function(Some(&func)); - - let pipeline = device - .new_compute_pipeline_state_with_function( - pipeline_state_descriptor.compute_function().unwrap(), - ) - .unwrap(); - + let pipeline = 
kernels.load_pipeline(device, Source::Reduce, kernel_name)?; let encoder = command_buffer.new_compute_command_encoder(); + encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); - set_params!(encoder, (length, elements_to_sum, input, output)); + set_params!( + encoder, + (length, elements_to_sum, (input, input_offset), output) + ); let out_length = length / elements_to_sum; @@ -532,7 +688,10 @@ pub fn call_last_softmax( depth: 1, }; + encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -542,34 +701,214 @@ pub fn call_affine( device: &Device, command_buffer: &CommandBufferRef, kernels: &Kernels, + name: &'static str, size: usize, input: &Buffer, - output: &mut Buffer, + output: &Buffer, + mul: f32, + add: f32, +) -> Result<(), MetalKernelError> { + let pipeline = kernels.load_pipeline(device, Source::Affine, name)?; + + let encoder = command_buffer.new_compute_command_encoder(); + encoder.wait_for_fence(&kernels.fence); + encoder.set_compute_pipeline_state(&pipeline); + + set_params!(encoder, (size, mul, add, input, output)); + + let (thread_group_count, thread_group_size) = linear_split(&pipeline, size); + encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.update_fence(&kernels.fence); + encoder.end_encoding(); + Ok(()) +} + +#[allow(clippy::too_many_arguments)] +pub fn call_affine_strided( + device: &Device, + command_buffer: &CommandBufferRef, + kernels: &Kernels, + name: &'static str, + shape: &[usize], + input: &Buffer, + input_stride: &[usize], + input_offset: usize, + output: &Buffer, mul: f32, add: f32, ) -> Result<(), MetalKernelError> { - let func = 
kernels.load_function(device, Source::Affine, "affine_float")?; - let pipeline_state_descriptor = ComputePipelineDescriptor::new(); - pipeline_state_descriptor.set_compute_function(Some(&func)); + let pipeline = kernels.load_pipeline(device, Source::Affine, name)?; + let size: usize = shape.iter().product(); + + let encoder = command_buffer.new_compute_command_encoder(); + encoder.wait_for_fence(&kernels.fence); + encoder.set_compute_pipeline_state(&pipeline); - let pipeline = device - .new_compute_pipeline_state_with_function( - pipeline_state_descriptor.compute_function().unwrap(), + set_params!( + encoder, + ( + size, + shape.len(), + shape, + input_stride, + mul, + add, + (input, input_offset), + output ) - .unwrap(); + ); + + let (thread_group_count, thread_group_size) = linear_split(&pipeline, size); + encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.update_fence(&kernels.fence); + encoder.end_encoding(); + Ok(()) +} + +#[allow(clippy::too_many_arguments)] +pub fn call_powf( + device: &Device, + command_buffer: &CommandBufferRef, + kernels: &Kernels, + name: &'static str, + size: usize, + input: &Buffer, + output: &Buffer, + mul: f32, +) -> Result<(), MetalKernelError> { + let pipeline = kernels.load_pipeline(device, Source::Affine, name)?; let encoder = command_buffer.new_compute_command_encoder(); + encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); - set_params!(encoder, (size, mul, add, input, output)); + set_params!(encoder, (size, mul, input, output)); + + let (thread_group_count, thread_group_size) = linear_split(&pipeline, size); + encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + 
encoder.update_fence(&kernels.fence); + encoder.end_encoding(); + Ok(()) +} + +#[allow(clippy::too_many_arguments)] +pub fn call_powf_strided( + device: &Device, + command_buffer: &CommandBufferRef, + kernels: &Kernels, + name: &'static str, + shape: &[usize], + input: &Buffer, + input_stride: &[usize], + input_offset: usize, + output: &Buffer, + mul: f32, +) -> Result<(), MetalKernelError> { + let pipeline = kernels.load_pipeline(device, Source::Affine, name)?; + let size: usize = shape.iter().product(); + + let encoder = command_buffer.new_compute_command_encoder(); + encoder.wait_for_fence(&kernels.fence); + encoder.set_compute_pipeline_state(&pipeline); + + set_params!( + encoder, + ( + size, + shape.len(), + shape, + input_stride, + mul, + (input, input_offset), + output + ) + ); + + let (thread_group_count, thread_group_size) = linear_split(&pipeline, size); + encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.update_fence(&kernels.fence); + encoder.end_encoding(); + Ok(()) +} + +#[allow(clippy::too_many_arguments)] +pub fn call_elu( + device: &Device, + command_buffer: &CommandBufferRef, + kernels: &Kernels, + name: &'static str, + size: usize, + input: &Buffer, + output: &Buffer, + mul: f32, +) -> Result<(), MetalKernelError> { + let pipeline = kernels.load_pipeline(device, Source::Affine, name)?; + + let encoder = command_buffer.new_compute_command_encoder(); + encoder.wait_for_fence(&kernels.fence); + encoder.set_compute_pipeline_state(&pipeline); + + set_params!(encoder, (size, mul, input, output)); let (thread_group_count, thread_group_size) = linear_split(&pipeline, size); + encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + 
encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } #[allow(clippy::too_many_arguments)] +pub fn call_elu_strided( + device: &Device, + command_buffer: &CommandBufferRef, + kernels: &Kernels, + name: &'static str, + shape: &[usize], + input: &Buffer, + input_stride: &[usize], + input_offset: usize, + output: &Buffer, + mul: f32, +) -> Result<(), MetalKernelError> { + let pipeline = kernels.load_pipeline(device, Source::Affine, name)?; + let size: usize = shape.iter().product(); + + let encoder = command_buffer.new_compute_command_encoder(); + encoder.wait_for_fence(&kernels.fence); + encoder.set_compute_pipeline_state(&pipeline); + + set_params!( + encoder, + ( + size, + shape.len(), + shape, + input_stride, + mul, + (input, input_offset), + output + ) + ); + + let (thread_group_count, thread_group_size) = linear_split(&pipeline, size); + encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.update_fence(&kernels.fence); + encoder.end_encoding(); + Ok(()) +} + pub fn call_where_cond_strided( device: &Device, command_buffer: &CommandBufferRef, @@ -582,19 +921,12 @@ pub fn call_where_cond_strided( (left_stride, left_offset): (&[usize], usize), right: &Buffer, (right_stride, right_offset): (&[usize], usize), - output: &mut Buffer, + output: &Buffer, ) -> Result<(), MetalKernelError> { - let func = kernels.load_function(device, Source::Ternary, name)?; - let pipeline_state_descriptor = ComputePipelineDescriptor::new(); - pipeline_state_descriptor.set_compute_function(Some(&func)); - - let pipeline = device - .new_compute_pipeline_state_with_function( - pipeline_state_descriptor.compute_function().unwrap(), - ) - .unwrap(); + let pipeline = kernels.load_pipeline(device, Source::Ternary, name)?; let encoder = command_buffer.new_compute_command_encoder(); + encoder.wait_for_fence(&kernels.fence); 
encoder.set_compute_pipeline_state(&pipeline); let size: usize = shape.iter().product(); @@ -618,7 +950,12 @@ pub fn call_where_cond_strided( let (thread_group_count, thread_group_size) = linear_split(&pipeline, size); + encoder.use_resource(cond, metal::MTLResourceUsage::Read); + encoder.use_resource(left, metal::MTLResourceUsage::Read); + encoder.use_resource(right, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } @@ -634,20 +971,18 @@ pub fn call_index_select( dim: usize, input: &Buffer, ids: &Buffer, - output: &mut Buffer, + output: &Buffer, ) -> Result<(), MetalKernelError> { let left_size: usize = shape[..dim].iter().product(); let right_size: usize = shape[dim + 1..].iter().product(); let src_dim_size = shape[dim]; let dst_el = ids_size * left_size * right_size; - let func = kernels.load_function(device, Source::Indexing, name)?; - let pipeline = device - .new_compute_pipeline_state_with_function(&func) - .unwrap(); + let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?; let encoder = command_buffer.new_compute_command_encoder(); + encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); set_params!( @@ -666,10 +1001,426 @@ pub fn call_index_select( let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el); + encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(ids, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.update_fence(&kernels.fence); encoder.end_encoding(); Ok(()) } +#[allow(clippy::too_many_arguments)] +pub fn call_gather( + device: &Device, + command_buffer: &CommandBufferRef, + kernels: &Kernels, + name: &'static str, + shape: &[usize], 
+ ids_size: usize, + dim: usize, + input: &Buffer, + input_offset: usize, + ids: &Buffer, + ids_offset: usize, + output: &Buffer, +) -> Result<(), MetalKernelError> { + let left_size: usize = shape[..dim].iter().product(); + let right_size: usize = shape[dim + 1..].iter().product(); + let src_dim_size = shape[dim]; + let dst_el = ids_size * left_size * right_size; + + let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?; + + let encoder = command_buffer.new_compute_command_encoder(); + + encoder.wait_for_fence(&kernels.fence); + encoder.set_compute_pipeline_state(&pipeline); + + set_params!( + encoder, + ( + dst_el, + left_size, + src_dim_size, + right_size, + ids_size, + (input, input_offset), + (ids, ids_offset), + output + ) + ); + + let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el); + + encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(ids, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.update_fence(&kernels.fence); + encoder.end_encoding(); + Ok(()) +} + +pub fn call_scatter_add( + device: &Device, + command_buffer: &CommandBufferRef, + kernels: &Kernels, + name: &'static str, + src_shape: &[usize], + dst_shape: &[usize], + dim: usize, + input: &Buffer, + input_offset: usize, + ids: &Buffer, + ids_offset: usize, + output: &Buffer, +) -> Result<(), MetalKernelError> { + let left_size: usize = src_shape[..dim].iter().product(); + let right_size: usize = src_shape[dim + 1..].iter().product(); + let src_dim_size = src_shape[dim]; + let dst_el = left_size * right_size; + let dst_dim_size = dst_shape[dim]; + + let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?; + + let encoder = command_buffer.new_compute_command_encoder(); + + encoder.wait_for_fence(&kernels.fence); + encoder.set_compute_pipeline_state(&pipeline); + + set_params!( + 
encoder, + ( + dst_el, + left_size, + src_dim_size, + right_size, + dst_dim_size, + (input, input_offset), + (ids, ids_offset), + output + ) + ); + + let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el); + + encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(ids, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.update_fence(&kernels.fence); + encoder.end_encoding(); + Ok(()) +} + +pub fn call_index_add( + device: &Device, + command_buffer: &CommandBufferRef, + kernels: &Kernels, + name: &'static str, + src_shape: &[usize], + dst_shape: &[usize], + ids_shape: &[usize], + dim: usize, + input: &Buffer, + input_offset: usize, + ids: &Buffer, + ids_offset: usize, + output: &Buffer, +) -> Result<(), MetalKernelError> { + let left_size: usize = src_shape[..dim].iter().product(); + let right_size: usize = src_shape[dim + 1..].iter().product(); + let src_dim_size = src_shape[dim]; + let dst_el = left_size * right_size; + let dst_dim_size = dst_shape[dim]; + let ids_dim_size = ids_shape[0]; + + let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?; + let encoder = command_buffer.new_compute_command_encoder(); + + encoder.wait_for_fence(&kernels.fence); + encoder.set_compute_pipeline_state(&pipeline); + + set_params!( + encoder, + ( + dst_el, + left_size, + src_dim_size, + right_size, + dst_dim_size, + ids_dim_size, + (input, input_offset), + (ids, ids_offset), + output + ) + ); + + let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el); + + encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(ids, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.update_fence(&kernels.fence); + 
encoder.end_encoding(); + Ok(()) +} + +#[derive(Debug, PartialEq)] +pub enum Value { + USize(usize), + Bool(bool), + F32(f32), + U16(u16), +} + +impl std::hash::Hash for Value { + fn hash<H: std::hash::Hasher>(&self, state: &mut H) { + match self { + Value::F32(v) => v.to_bits().hash(state), + Value::USize(v) => v.hash(state), + Value::U16(v) => v.hash(state), + Value::Bool(v) => v.hash(state), + } + } +} + +impl Value { + fn data_type(&self) -> MTLDataType { + match self { + Value::USize(_) => MTLDataType::UInt, + Value::F32(_) => MTLDataType::Float, + Value::U16(_) => MTLDataType::UShort, + Value::Bool(_) => MTLDataType::Bool, + } + } +} + +/// Not true, good enough for our purposes. +impl Eq for Value {} + +#[derive(Debug, Eq, PartialEq, Hash)] +struct ConstantValues(Vec<(usize, Value)>); + +impl ConstantValues { + pub fn new(values: Vec<(usize, Value)>) -> Self { + Self(values) + } + + fn function_constant_values(&self) -> FunctionConstantValues { + let f = FunctionConstantValues::new(); + for (index, value) in &self.0 { + let ty = value.data_type(); + match value { + Value::USize(v) => { + f.set_constant_value_at_index( + v as *const usize as *const c_void, + ty, + *index as u64, + ); + } + Value::F32(v) => { + f.set_constant_value_at_index( + v as *const f32 as *const c_void, + ty, + *index as u64, + ); + } + Value::U16(v) => { + f.set_constant_value_at_index( + v as *const u16 as *const c_void, + ty, + *index as u64, + ); + } + Value::Bool(v) => { + f.set_constant_value_at_index( + v as *const bool as *const c_void, + ty, + *index as u64, + ); + } + } + } + f + } +} + +#[allow(clippy::too_many_arguments)] +pub fn call_gemm( + device: &Device, + command_buffer: &CommandBufferRef, + kernels: &Kernels, + name: &'static str, + (b, m, n, k): (usize, usize, usize, usize), + lhs_stride: &[usize], + lhs_offset: usize, + lhs_buffer: &Buffer, + rhs_stride: &[usize], + rhs_offset: usize, + rhs_buffer: &Buffer, + output: &Buffer, +) -> Result<(), MetalKernelError> { + 
assert!(rhs_stride.len() >= 2); + assert!(lhs_stride.len() >= 2); + let rhs_m1 = rhs_stride[rhs_stride.len() - 1]; + let rhs_m2 = rhs_stride[rhs_stride.len() - 2]; + let lhs_m1 = lhs_stride[lhs_stride.len() - 1]; + let lhs_m2 = lhs_stride[lhs_stride.len() - 2]; + let a_trans = if lhs_m1 == 1 && lhs_m2 == k { + false + } else if lhs_m1 == m && lhs_m2 == 1 { + true + } else { + return Err(MetalKernelError::MatMulNonContiguous { + lhs_stride: lhs_stride.to_vec(), + rhs_stride: rhs_stride.to_vec(), + mnk: (m, n, k), + })?; + }; + let b_trans = if rhs_m1 == 1 && rhs_m2 == n { + false + } else if rhs_m1 == k && rhs_m2 == 1 { + true + } else { + return Err(MetalKernelError::MatMulNonContiguous { + lhs_stride: lhs_stride.to_vec(), + rhs_stride: rhs_stride.to_vec(), + mnk: (m, n, k), + })?; + }; + let d_trans = false; + let alpha = 1.0f32; + let beta = 0.0f32; + let batched = b > 1; + let fused_activation = false; + let fused_bias = false; + let (m_simd, n_simd, k_simd, m_splits, n_splits) = if m == 1 { + let m_simd = 16; + let n_simd = 8; + let k_simd = 64; + let m_splits = 1; + let n_splits = 1; + (m_simd, n_simd, k_simd, m_splits, n_splits) + } else { + let m_simd = 40; + let n_simd = 40; + let k_simd = 8; + let m_splits = 1; + let n_splits = 1; + (m_simd, n_simd, k_simd, m_splits, n_splits) + }; + let constants = Some(ConstantValues::new(vec![ + (0, Value::USize(m)), + (1, Value::USize(n)), + (2, Value::USize(k)), + (10, Value::Bool(a_trans)), + (11, Value::Bool(b_trans)), + (13, Value::Bool(d_trans)), + (20, Value::F32(alpha)), + (21, Value::F32(beta)), + (100, Value::Bool(batched)), + (101, Value::Bool(fused_activation)), + // Garbage + (102, Value::Bool(false)), + (103, Value::Bool(false)), + (113, Value::Bool(false)), + (50_000, Value::Bool(false)), + // End garbage + (200, Value::U16(m_simd)), + (201, Value::U16(n_simd)), + (202, Value::U16(k_simd)), + (210, Value::U16(m_splits)), + (211, Value::U16(n_splits)), + (50_001, Value::Bool(fused_bias)), + ])); + let 
pipeline = kernels.load_pipeline_with_constants(device, Source::Mfa, name, constants)?; + let m_group = m_simd * m_splits; + let n_group = n_simd * n_splits; + + let a_block_length = m_group * k_simd; + let b_block_length = k_simd * n_group; + + let mut block_elements = a_block_length + b_block_length; + if (m % 8 != 0) && (n % 8 != 0) { + let c_block_length = m_group * n_group; + block_elements = std::cmp::max(c_block_length, block_elements) + } + if fused_bias { + if d_trans { + block_elements = std::cmp::max(block_elements, m_group); + } else { + block_elements = std::cmp::max(block_elements, n_group); + } + } + let bytes = match name { + "sgemm" => 4, + "hgemm" => 2, + other => { + return Err(MetalKernelError::LoadLibraryError(format!( + "{other} is not a valid kernel for gemm" + ))); + } + }; + let block_bytes = block_elements * bytes; + + let encoder = command_buffer.new_compute_command_encoder(); + encoder.wait_for_fence(&kernels.fence); + encoder.set_compute_pipeline_state(&pipeline); + encoder.set_threadgroup_memory_length(0, block_bytes.into()); + encoder.set_buffer(0, Some(lhs_buffer), lhs_offset as NSUInteger); + encoder.set_buffer(1, Some(rhs_buffer), rhs_offset as NSUInteger); + encoder.set_buffer(2, Some(output), 0); + // TODO Tensor D + + let grid_z = b; + if batched { + let byte_stride_a: usize = lhs_stride[lhs_stride.len() - 3] * bytes as usize; + let byte_stride_b: usize = rhs_stride[rhs_stride.len() - 3] * bytes as usize; + let byte_stride_c = m * n * bytes as usize; + // TODO byte_stride_d + let byte_stride_d = 0; + + let mut buffer: Vec<u64> = Vec::with_capacity(b * 4); + for i in 0..b { + buffer.push((i * byte_stride_a) as u64); + buffer.push((i * byte_stride_b) as u64); + buffer.push((i * byte_stride_c) as u64); + buffer.push((i * byte_stride_d) as u64); + } + encoder.set_bytes( + 10, + (buffer.len() * core::mem::size_of::<u64>()) as NSUInteger, + buffer.as_ptr() as *const NSUInteger as *const c_void, + ); + } + + let grid_size = MTLSize { + 
width: divide(n, n_group.into()), + height: divide(m, m_group.into()), + depth: grid_z as NSUInteger, + }; + let group_size = MTLSize { + width: 32 * (m_splits as u64) * (n_splits as u64), + height: 1, + depth: 1, + }; + // println!("grid size {grid_size:?} group size {group_size:?}"); + encoder.use_resource(lhs_buffer, metal::MTLResourceUsage::Read); + encoder.use_resource(rhs_buffer, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); + encoder.dispatch_thread_groups(grid_size, group_size); + encoder.update_fence(&kernels.fence); + encoder.end_encoding(); + + Ok(()) +} + +fn divide(m: usize, b: usize) -> NSUInteger { + ((m + b - 1) / b) as NSUInteger +} + #[cfg(test)] mod tests; |