summary refs log tree commit diff
path: root/candle-metal-kernels/src/lib.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-metal-kernels/src/lib.rs')
-rw-r--r-- candle-metal-kernels/src/lib.rs 117
1 files changed, 82 insertions, 35 deletions
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index e05797a2..10f942b4 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -74,6 +74,30 @@ macro_rules! ops{
}
}
+ pub mod contiguous_tiled {
+ pub struct Kernel(pub &'static str);
+ $(
+ pub mod $name {
+ use super::Kernel;
+ pub const FLOAT: Kernel = Kernel(concat!(stringify!($name), "_f32_tiled"));
+ pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_f16_tiled"));
+ pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bf16_tiled"));
+ pub const I64: Kernel = Kernel(concat!(stringify!($name), "_i64_tiled"));
+ pub const U32: Kernel = Kernel(concat!(stringify!($name), "_u32_tiled"));
+ pub const U8: Kernel = Kernel(concat!(stringify!($name), "_u8_tiled"));
+ }
+ )+
+ pub mod copy {
+ use super::Kernel;
+ pub const FLOAT: Kernel = Kernel("copy_f32_tiled");
+ pub const HALF: Kernel = Kernel("copy_f16_tiled");
+ pub const BFLOAT: Kernel = Kernel("copy_bf16_tiled");
+ pub const I64: Kernel = Kernel("copy_i64_tiled");
+ pub const U32: Kernel = Kernel("copy_u32_tiled");
+ pub const U8: Kernel = Kernel("copy_u8_tiled");
+ }
+ }
+
pub mod strided {
pub struct Kernel(pub &'static str);
$(
@@ -268,30 +292,6 @@ impl Kernels {
}
#[allow(clippy::too_many_arguments)]
-pub fn call_unary_contiguous(
- device: &Device,
- command_buffer: &CommandBufferRef,
- kernels: &Kernels,
- kernel_name: unary::contiguous::Kernel,
- length: usize,
- input: BufferOffset,
- output: &Buffer,
-) -> Result<(), MetalKernelError> {
- let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
- let encoder = command_buffer.new_compute_command_encoder();
- encoder.set_compute_pipeline_state(&pipeline);
-
- set_params!(encoder, (length, &input, output));
-
- let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
- encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
- encoder.use_resource(output, metal::MTLResourceUsage::Write);
- encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
- encoder.end_encoding();
- Ok(())
-}
-
-#[allow(clippy::too_many_arguments)]
pub fn call_copy2d(
device: &Device,
command_buffer: &CommandBufferRef,
@@ -335,6 +335,58 @@ pub fn call_copy2d(
}
#[allow(clippy::too_many_arguments)]
+pub fn call_unary_contiguous_tiled(
+ device: &Device,
+ command_buffer: &CommandBufferRef,
+ kernels: &Kernels,
+ kernel_name: unary::contiguous_tiled::Kernel,
+ length: usize,
+ input: BufferOffset,
+ output: &Buffer,
+) -> Result<(), MetalKernelError> {
+ let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
+ let encoder = command_buffer.new_compute_command_encoder();
+ let tile_size = 2;
+ let tiles = length.div_ceil(tile_size);
+
+ encoder.set_compute_pipeline_state(&pipeline);
+
+ set_params!(encoder, (length, &input, output));
+
+ let (thread_group_count, thread_group_size) = linear_split(&pipeline, tiles);
+ encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
+ encoder.use_resource(output, metal::MTLResourceUsage::Write);
+ encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+ encoder.end_encoding();
+ Ok(())
+}
+
+#[allow(clippy::too_many_arguments)]
+pub fn call_unary_contiguous(
+ device: &Device,
+ command_buffer: &CommandBufferRef,
+ kernels: &Kernels,
+ kernel_name: unary::contiguous::Kernel,
+ length: usize,
+ input: BufferOffset,
+ output: &Buffer,
+) -> Result<(), MetalKernelError> {
+ let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
+ let encoder = command_buffer.new_compute_command_encoder();
+
+ encoder.set_compute_pipeline_state(&pipeline);
+
+ set_params!(encoder, (length, &input, output));
+
+ let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
+ encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
+ encoder.use_resource(output, metal::MTLResourceUsage::Write);
+ encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+ encoder.end_encoding();
+ Ok(())
+}
+
+#[allow(clippy::too_many_arguments)]
pub fn call_unary_strided(
device: &Device,
command_buffer: &CommandBufferRef,
@@ -347,16 +399,13 @@ pub fn call_unary_strided(
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Unary, name.0)?;
+ let length: usize = shape.iter().product();
let num_dims: usize = shape.len();
let encoder = command_buffer.new_compute_command_encoder();
- encoder.set_compute_pipeline_state(&pipeline);
+ let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
- let length: usize = shape.iter().product();
+ encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (length, num_dims, shape, strides, &input, &output));
-
- let width: usize = shape.iter().product();
- let (thread_group_count, thread_group_size) = linear_split(&pipeline, width);
-
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
@@ -410,10 +459,10 @@ pub fn call_binary_strided(
let num_dims: usize = shape.len();
let encoder = command_buffer.new_compute_command_encoder();
let width: usize = shape.iter().product();
- encoder.set_compute_pipeline_state(&pipeline);
-
let length: usize = shape.iter().product();
+ let (thread_group_count, thread_group_size) = linear_split(&pipeline, width);
+ encoder.set_compute_pipeline_state(&pipeline);
set_params!(
encoder,
(
@@ -427,14 +476,12 @@ pub fn call_binary_strided(
output
)
);
-
- let (thread_group_count, thread_group_size) = linear_split(&pipeline, width);
-
encoder.use_resource(left_input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(right_input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
+
Ok(())
}