summaryrefslogtreecommitdiff
path: root/candle-metal-kernels
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2024-04-07 12:34:16 +0200
committerGitHub <noreply@github.com>2024-04-07 12:34:16 +0200
commit7f354473cf495db4554e08f84be44ed498f1aa5e (patch)
tree494e7edc590b754d17cd6da4608edcb24d9dd239 /candle-metal-kernels
parent33c9b6655459bd1086574cef9ba8f2e72a8804c8 (diff)
downloadcandle-7f354473cf495db4554e08f84be44ed498f1aa5e.tar.gz
candle-7f354473cf495db4554e08f84be44ed498f1aa5e.tar.bz2
candle-7f354473cf495db4554e08f84be44ed498f1aa5e.zip
Optimize copy-2d for metal. (#2024)
* Optimize copy-2d for metal. * Add a hacky stopping rule for moondream.
Diffstat (limited to 'candle-metal-kernels')
-rw-r--r--candle-metal-kernels/src/lib.rs57
-rw-r--r--candle-metal-kernels/src/unary.metal20
2 files changed, 57 insertions, 20 deletions
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index 4cff9bda..8b9be670 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -40,6 +40,44 @@ fn linear_split(pipeline: &ComputePipelineState, length: usize) -> (MTLSize, MTL
(thread_group_count, thread_group_size)
}
+// https://github.com/ml-explore/mlx/blob/bddf23f175726a57f0e443cd45518c0757daa166/mlx/backend/metal/utils.h#L96
+fn get_block_dims(dim0: u64, dim1: u64, dim2: u64) -> MTLSize {
+ let mut pows0 = 0u64;
+ let mut pows1 = 0u64;
+ let mut pows2 = 0u64;
+ let mut sum = 0u64;
+ loop {
+ let presum = sum;
+ // Check all the pows
+ if dim0 >= (1 << (pows0 + 1)) {
+ pows0 += 1;
+ sum += 1;
+ }
+ if sum == 10 {
+ break;
+ }
+ if dim1 >= (1 << (pows1 + 1)) {
+ pows1 += 1;
+ sum += 1;
+ }
+ if sum == 10 {
+ break;
+ }
+ if dim2 >= (1 << (pows2 + 1)) {
+ pows2 += 1;
+ sum += 1;
+ }
+ if sum == presum || sum == 10 {
+ break;
+ }
+ }
+ MTLSize {
+ width: 1 << pows0,
+ height: 1 << pows1,
+ depth: 1 << pows2,
+ }
+}
+
fn set_param<P: EncoderParam>(encoder: &ComputeCommandEncoderRef, position: u64, data: P) {
<P as EncoderParam>::set_param(encoder, position, data)
}
@@ -396,21 +434,24 @@ pub fn call_copy2d(
set_params!(
encoder,
(
- d1,
- d2,
- src_s,
- dst_s,
+ d1 as i64,
+ d2 as i64,
+ src_s as i64,
+ dst_s as i64,
(input, src_o_in_bytes),
(output, dst_o_in_bytes)
)
);
- let width: usize = d1 * d2;
- let (thread_group_count, thread_group_size) = linear_split(&pipeline, width);
-
+ let grid_dims = MTLSize {
+ width: d1 as u64,
+ height: d2 as u64,
+ depth: 1,
+ };
+ let group_dims = get_block_dims(d1 as u64, d2 as u64, 1);
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
- encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+ encoder.dispatch_threads(grid_dims, group_dims);
encoder.end_encoding();
Ok(())
}
diff --git a/candle-metal-kernels/src/unary.metal b/candle-metal-kernels/src/unary.metal
index 809522d7..4b6363ed 100644
--- a/candle-metal-kernels/src/unary.metal
+++ b/candle-metal-kernels/src/unary.metal
@@ -104,21 +104,17 @@ UNARY(NAME, bfloat, NAME##_bf16, NAME##_bf16_strided);
#define COPY2D(FN_NAME, TYPENAME) \
kernel void FN_NAME( \
- constant size_t &d1, \
- constant size_t &d2, \
- constant size_t &src_s, \
- constant size_t &dst_s, \
+ constant int64_t &d1, \
+ constant int64_t &d2, \
+ constant int64_t &src_s, \
+ constant int64_t &dst_s, \
device const TYPENAME *input, \
device TYPENAME *output, \
- uint tid [[ thread_position_in_grid ]] \
+ uint2 idx [[thread_position_in_grid]] \
) { \
- if (tid >= d1 * d2) { \
- return; \
- } \
- size_t idx1 = tid / d2; \
- size_t idx2 = tid - idx1 * d2; \
- size_t src_idx = idx1 * src_s + idx2; \
- size_t dst_idx = idx1 * dst_s + idx2; \
+ if (idx.x >= d1 || idx.y >= d2) return; \
+ int64_t src_idx = idx.x * src_s + idx.y; \
+ int64_t dst_idx = idx.x * dst_s + idx.y; \
output[dst_idx] = input[src_idx]; \
}