diff options
author | Thomas Santerre <thomas@santerre.xyz> | 2024-04-20 18:10:33 -0400 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-04-21 00:10:33 +0200 |
commit | 0067fe00a8477b8c817dcf54d4d4084b07b7fc5b (patch) | |
tree | ea84cb8d6f814224da42c281f96745a8658d24eb /candle-metal-kernels | |
parent | 587ee3bb6fd2b4c2b7bbe7e97751cac96249dd6d (diff) | |
download | candle-0067fe00a8477b8c817dcf54d4d4084b07b7fc5b.tar.gz candle-0067fe00a8477b8c817dcf54d4d4084b07b7fc5b.tar.bz2 candle-0067fe00a8477b8c817dcf54d4d4084b07b7fc5b.zip |
Metal Unary: Add benchmarks and process kernels in a tile based fashion (#2056)
* add basic unary bench for sqrt
* process unary commands in tiles of 4
* re-enable all benchmarks
* rename helper to unary
* modify approach to split up tiled and non-tiled operations
* undo bench ignore for other tests
* update tile size to 2
* only perform the optimization on the contiguous even numbered element case
Diffstat (limited to 'candle-metal-kernels')
-rw-r--r-- | candle-metal-kernels/src/lib.rs | 117 | ||||
-rw-r--r-- | candle-metal-kernels/src/unary.metal | 17 |
2 files changed, 97 insertions, 37 deletions
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index e05797a2..10f942b4 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -74,6 +74,30 @@ macro_rules! ops{ } } + pub mod contiguous_tiled { + pub struct Kernel(pub &'static str); + $( + pub mod $name { + use super::Kernel; + pub const FLOAT: Kernel = Kernel(concat!(stringify!($name), "_f32_tiled")); + pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_f16_tiled")); + pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bf16_tiled")); + pub const I64: Kernel = Kernel(concat!(stringify!($name), "_i64_tiled")); + pub const U32: Kernel = Kernel(concat!(stringify!($name), "_u32_tiled")); + pub const U8: Kernel = Kernel(concat!(stringify!($name), "_u8_tiled")); + } + )+ + pub mod copy { + use super::Kernel; + pub const FLOAT: Kernel = Kernel("copy_f32_tiled"); + pub const HALF: Kernel = Kernel("copy_f16_tiled"); + pub const BFLOAT: Kernel = Kernel("copy_bf16_tiled"); + pub const I64: Kernel = Kernel("copy_i64_tiled"); + pub const U32: Kernel = Kernel("copy_u32_tiled"); + pub const U8: Kernel = Kernel("copy_u8_tiled"); + } + } + pub mod strided { pub struct Kernel(pub &'static str); $( @@ -268,30 +292,6 @@ impl Kernels { } #[allow(clippy::too_many_arguments)] -pub fn call_unary_contiguous( - device: &Device, - command_buffer: &CommandBufferRef, - kernels: &Kernels, - kernel_name: unary::contiguous::Kernel, - length: usize, - input: BufferOffset, - output: &Buffer, -) -> Result<(), MetalKernelError> { - let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?; - let encoder = command_buffer.new_compute_command_encoder(); - encoder.set_compute_pipeline_state(&pipeline); - - set_params!(encoder, (length, &input, output)); - - let (thread_group_count, thread_group_size) = linear_split(&pipeline, length); - encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); - encoder.use_resource(output, metal::MTLResourceUsage::Write); - encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - encoder.end_encoding(); - Ok(()) -} - -#[allow(clippy::too_many_arguments)] pub fn call_copy2d( device: &Device, command_buffer: &CommandBufferRef, @@ -335,6 +335,58 @@ pub fn call_copy2d( } #[allow(clippy::too_many_arguments)] +pub fn call_unary_contiguous_tiled( + device: &Device, + command_buffer: &CommandBufferRef, + kernels: &Kernels, + kernel_name: unary::contiguous_tiled::Kernel, + length: usize, + input: BufferOffset, + output: &Buffer, +) -> Result<(), MetalKernelError> { + let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?; + let encoder = command_buffer.new_compute_command_encoder(); + let tile_size = 2; + let tiles = length.div_ceil(tile_size); + + encoder.set_compute_pipeline_state(&pipeline); + + set_params!(encoder, (length, &input, output)); + + let (thread_group_count, thread_group_size) = linear_split(&pipeline, tiles); + encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.end_encoding(); + Ok(()) +} + +#[allow(clippy::too_many_arguments)] +pub fn call_unary_contiguous( + device: &Device, + command_buffer: &CommandBufferRef, + kernels: &Kernels, + kernel_name: unary::contiguous::Kernel, + length: usize, + input: BufferOffset, + output: &Buffer, +) -> Result<(), MetalKernelError> { + let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?; + let encoder = command_buffer.new_compute_command_encoder(); + + encoder.set_compute_pipeline_state(&pipeline); + + set_params!(encoder, (length, &input, output)); + + let (thread_group_count, thread_group_size) = linear_split(&pipeline, length); + encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.end_encoding(); + Ok(()) +} + +#[allow(clippy::too_many_arguments)] pub fn call_unary_strided( device: &Device, command_buffer: &CommandBufferRef, @@ -347,16 +399,13 @@ pub fn call_unary_strided( ) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, Source::Unary, name.0)?; + let length: usize = shape.iter().product(); let num_dims: usize = shape.len(); let encoder = command_buffer.new_compute_command_encoder(); - encoder.set_compute_pipeline_state(&pipeline); + let (thread_group_count, thread_group_size) = linear_split(&pipeline, length); - let length: usize = shape.iter().product(); + encoder.set_compute_pipeline_state(&pipeline); set_params!(encoder, (length, num_dims, shape, strides, &input, &output)); - - let width: usize = shape.iter().product(); - let (thread_group_count, thread_group_size) = linear_split(&pipeline, width); - encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); @@ -410,10 +459,10 @@ pub fn call_binary_strided( let num_dims: usize = shape.len(); let encoder = command_buffer.new_compute_command_encoder(); let width: usize = shape.iter().product(); - encoder.set_compute_pipeline_state(&pipeline); - let length: usize = shape.iter().product(); + let (thread_group_count, thread_group_size) = linear_split(&pipeline, width); + encoder.set_compute_pipeline_state(&pipeline); set_params!( encoder, ( @@ -427,14 +476,12 @@ pub fn call_binary_strided( output ) ); - - let (thread_group_count, thread_group_size) = linear_split(&pipeline, width); - encoder.use_resource(left_input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(right_input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.end_encoding(); + Ok(()) } diff --git a/candle-metal-kernels/src/unary.metal b/candle-metal-kernels/src/unary.metal index ec793eae..143e9500 100644 --- a/candle-metal-kernels/src/unary.metal +++ b/candle-metal-kernels/src/unary.metal @@ -68,6 +68,8 @@ template <typename T> METAL_FUNC T silu(T in){ return in / (static_cast<T>(1) + exp(-in)); } +#define TILE_SIZE 2 + #define UNARY(FN, TYPENAME, FN_NAME, FN_NAME_STRIDED) \ kernel void FN_NAME( \ constant size_t &dim, \ @@ -79,8 +81,8 @@ kernel void FN_NAME( \ return; \ } \ output[tid] = TYPENAME(FN(float(input[tid]))); \ -}\ -kernel void FN_NAME_STRIDED( \ +} \ +kernel void FN_NAME##_##strided( \ constant size_t &dim, \ constant size_t &num_dims, \ constant size_t *dims, \ @@ -93,6 +95,17 @@ kernel void FN_NAME_STRIDED( \ return; \ } \ output[tid] = TYPENAME(FN(float(input[get_strided_index(tid, num_dims, dims, strides)]))); \ +} \ +kernel void FN_NAME##_##tiled( \ + constant size_t &dim, \ + device const TYPENAME *input, \ + device TYPENAME *output, \ + uint tid [[ thread_position_in_grid ]] \ +) { \ + for (uint i = 0; i < TILE_SIZE; i++) { \ + const uint idx = tid * TILE_SIZE + i; \ + output[idx] = TYPENAME(FN(float(input[idx]))); \ + } \ } #define UNARY_OP(NAME) \ |