summaryrefslogtreecommitdiff
path: root/candle-metal-kernels
diff options
context:
space:
mode:
authorThomas Santerre <thomas@santerre.xyz>2024-04-20 18:10:33 -0400
committerGitHub <noreply@github.com>2024-04-21 00:10:33 +0200
commit0067fe00a8477b8c817dcf54d4d4084b07b7fc5b (patch)
treeea84cb8d6f814224da42c281f96745a8658d24eb /candle-metal-kernels
parent587ee3bb6fd2b4c2b7bbe7e97751cac96249dd6d (diff)
downloadcandle-0067fe00a8477b8c817dcf54d4d4084b07b7fc5b.tar.gz
candle-0067fe00a8477b8c817dcf54d4d4084b07b7fc5b.tar.bz2
candle-0067fe00a8477b8c817dcf54d4d4084b07b7fc5b.zip
Metal Unary: Add benchmarks and process kernels in a tile based fashion (#2056)
* add basic unary bench for sqrt * process unary commands in tiles of 4 * re-enable all benchmarks * rename helper to unary * modify approach to split up tiled and non-tiled operations * undo bench ignore for other tests * update tile size to 2 * only perform the optimization on the contiguous even numbered element case
Diffstat (limited to 'candle-metal-kernels')
-rw-r--r--candle-metal-kernels/src/lib.rs117
-rw-r--r--candle-metal-kernels/src/unary.metal17
2 files changed, 97 insertions, 37 deletions
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index e05797a2..10f942b4 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -74,6 +74,30 @@ macro_rules! ops{
}
}
+ pub mod contiguous_tiled {
+ pub struct Kernel(pub &'static str);
+ $(
+ pub mod $name {
+ use super::Kernel;
+ pub const FLOAT: Kernel = Kernel(concat!(stringify!($name), "_f32_tiled"));
+ pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_f16_tiled"));
+ pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bf16_tiled"));
+ pub const I64: Kernel = Kernel(concat!(stringify!($name), "_i64_tiled"));
+ pub const U32: Kernel = Kernel(concat!(stringify!($name), "_u32_tiled"));
+ pub const U8: Kernel = Kernel(concat!(stringify!($name), "_u8_tiled"));
+ }
+ )+
+ pub mod copy {
+ use super::Kernel;
+ pub const FLOAT: Kernel = Kernel("copy_f32_tiled");
+ pub const HALF: Kernel = Kernel("copy_f16_tiled");
+ pub const BFLOAT: Kernel = Kernel("copy_bf16_tiled");
+ pub const I64: Kernel = Kernel("copy_i64_tiled");
+ pub const U32: Kernel = Kernel("copy_u32_tiled");
+ pub const U8: Kernel = Kernel("copy_u8_tiled");
+ }
+ }
+
pub mod strided {
pub struct Kernel(pub &'static str);
$(
@@ -268,30 +292,6 @@ impl Kernels {
}
#[allow(clippy::too_many_arguments)]
-pub fn call_unary_contiguous(
- device: &Device,
- command_buffer: &CommandBufferRef,
- kernels: &Kernels,
- kernel_name: unary::contiguous::Kernel,
- length: usize,
- input: BufferOffset,
- output: &Buffer,
-) -> Result<(), MetalKernelError> {
- let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
- let encoder = command_buffer.new_compute_command_encoder();
- encoder.set_compute_pipeline_state(&pipeline);
-
- set_params!(encoder, (length, &input, output));
-
- let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
- encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
- encoder.use_resource(output, metal::MTLResourceUsage::Write);
- encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
- encoder.end_encoding();
- Ok(())
-}
-
-#[allow(clippy::too_many_arguments)]
pub fn call_copy2d(
device: &Device,
command_buffer: &CommandBufferRef,
@@ -335,6 +335,58 @@ pub fn call_copy2d(
}
#[allow(clippy::too_many_arguments)]
+pub fn call_unary_contiguous_tiled(
+ device: &Device,
+ command_buffer: &CommandBufferRef,
+ kernels: &Kernels,
+ kernel_name: unary::contiguous_tiled::Kernel,
+ length: usize,
+ input: BufferOffset,
+ output: &Buffer,
+) -> Result<(), MetalKernelError> {
+ let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
+ let encoder = command_buffer.new_compute_command_encoder();
+ let tile_size = 2;
+ let tiles = length.div_ceil(tile_size);
+
+ encoder.set_compute_pipeline_state(&pipeline);
+
+ set_params!(encoder, (length, &input, output));
+
+ let (thread_group_count, thread_group_size) = linear_split(&pipeline, tiles);
+ encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
+ encoder.use_resource(output, metal::MTLResourceUsage::Write);
+ encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+ encoder.end_encoding();
+ Ok(())
+}
+
+#[allow(clippy::too_many_arguments)]
+pub fn call_unary_contiguous(
+ device: &Device,
+ command_buffer: &CommandBufferRef,
+ kernels: &Kernels,
+ kernel_name: unary::contiguous::Kernel,
+ length: usize,
+ input: BufferOffset,
+ output: &Buffer,
+) -> Result<(), MetalKernelError> {
+ let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
+ let encoder = command_buffer.new_compute_command_encoder();
+
+ encoder.set_compute_pipeline_state(&pipeline);
+
+ set_params!(encoder, (length, &input, output));
+
+ let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
+ encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
+ encoder.use_resource(output, metal::MTLResourceUsage::Write);
+ encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+ encoder.end_encoding();
+ Ok(())
+}
+
+#[allow(clippy::too_many_arguments)]
pub fn call_unary_strided(
device: &Device,
command_buffer: &CommandBufferRef,
@@ -347,16 +399,13 @@ pub fn call_unary_strided(
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Unary, name.0)?;
+ let length: usize = shape.iter().product();
let num_dims: usize = shape.len();
let encoder = command_buffer.new_compute_command_encoder();
- encoder.set_compute_pipeline_state(&pipeline);
+ let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
- let length: usize = shape.iter().product();
+ encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (length, num_dims, shape, strides, &input, &output));
-
- let width: usize = shape.iter().product();
- let (thread_group_count, thread_group_size) = linear_split(&pipeline, width);
-
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
@@ -410,10 +459,10 @@ pub fn call_binary_strided(
let num_dims: usize = shape.len();
let encoder = command_buffer.new_compute_command_encoder();
let width: usize = shape.iter().product();
- encoder.set_compute_pipeline_state(&pipeline);
-
let length: usize = shape.iter().product();
+ let (thread_group_count, thread_group_size) = linear_split(&pipeline, width);
+ encoder.set_compute_pipeline_state(&pipeline);
set_params!(
encoder,
(
@@ -427,14 +476,12 @@ pub fn call_binary_strided(
output
)
);
-
- let (thread_group_count, thread_group_size) = linear_split(&pipeline, width);
-
encoder.use_resource(left_input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(right_input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
+
Ok(())
}
diff --git a/candle-metal-kernels/src/unary.metal b/candle-metal-kernels/src/unary.metal
index ec793eae..143e9500 100644
--- a/candle-metal-kernels/src/unary.metal
+++ b/candle-metal-kernels/src/unary.metal
@@ -68,6 +68,8 @@ template <typename T> METAL_FUNC T silu(T in){
return in / (static_cast<T>(1) + exp(-in));
}
+#define TILE_SIZE 2
+
#define UNARY(FN, TYPENAME, FN_NAME, FN_NAME_STRIDED) \
kernel void FN_NAME( \
constant size_t &dim, \
@@ -79,8 +81,8 @@ kernel void FN_NAME( \
return; \
} \
output[tid] = TYPENAME(FN(float(input[tid]))); \
-}\
-kernel void FN_NAME_STRIDED( \
+} \
+kernel void FN_NAME##_##strided( \
constant size_t &dim, \
constant size_t &num_dims, \
constant size_t *dims, \
@@ -93,6 +95,17 @@ kernel void FN_NAME_STRIDED( \
return; \
} \
output[tid] = TYPENAME(FN(float(input[get_strided_index(tid, num_dims, dims, strides)]))); \
+} \
+kernel void FN_NAME##_##tiled( \
+ constant size_t &dim, \
+ device const TYPENAME *input, \
+ device TYPENAME *output, \
+ uint tid [[ thread_position_in_grid ]] \
+) { \
+ for (uint i = 0; i < TILE_SIZE; i++) { \
+ const uint idx = tid * TILE_SIZE + i; \
+ output[idx] = TYPENAME(FN(float(input[idx]))); \
+ } \
}
#define UNARY_OP(NAME) \