author     Laurent Mazare <laurent.mazare@gmail.com>  2024-04-07 22:37:53 +0200
committer  GitHub <noreply@github.com>                2024-04-07 22:37:53 +0200
commit     c5fe4a7f8983ae7c9641fa923f26ef60538aef06 (patch)
tree       12ad3e2445577fc77a5f9879686d554aea943a0d /candle-metal-kernels
parent     7f354473cf495db4554e08f84be44ed498f1aa5e (diff)
Rework the buffer offset logic for metal kernels (#2028)
* Move the metal kernels utils in a separate module.
* Use the BufferOffset for unary ops.
* Fix clippy lints.
* Use the new BufferOffset.
* Adapt the binary ops.
* Affine.
* More ops (powf, elu, cast).
Diffstat (limited to 'candle-metal-kernels')
-rw-r--r--  candle-metal-kernels/src/lib.rs    289
-rw-r--r--  candle-metal-kernels/src/tests.rs   58
-rw-r--r--  candle-metal-kernels/src/utils.rs  162
3 files changed, 262 insertions, 247 deletions
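
The heart of this change is the new BufferOffset helper: instead of threading a raw &Buffer plus a separate usize offset through every kernel entry point, callers now pass the pair as a single value. A minimal before/after sketch of a strided unary call (hand-written for illustration, inside a function returning Result<(), MetalKernelError>; device, command_buffer, kernels, kernel, shape, strides, input, output and offset are assumed to be in scope):

    // Before this patch: the byte offset travelled as a separate argument.
    // call_unary_strided(&device, command_buffer, &kernels, kernel, shape,
    //                    &input, strides, offset, &output, 0)?;

    // After this patch: buffer and byte offset are bundled together.
    let src = BufferOffset { buffer: &input, offset_in_bytes: offset };
    let dst = BufferOffset::zero_offset(&output);
    call_unary_strided(&device, command_buffer, &kernels, kernel, shape,
                       src, strides, dst)?;
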
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index 8b9be670..23c072af 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -1,11 +1,15 @@
use metal::{
- Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineState,
- Device, Function, FunctionConstantValues, Library, MTLDataType, MTLSize, NSUInteger,
+ Buffer, CommandBufferRef, CompileOptions, ComputePipelineState, Device, Function,
+ FunctionConstantValues, Library, MTLDataType, MTLSize, NSUInteger,
};
use std::collections::HashMap;
use std::ffi::c_void;
use std::sync::RwLock;
+mod utils;
+pub use utils::BufferOffset;
+use utils::{get_block_dims, linear_split};
+
const AFFINE: &str = include_str!("affine.metal");
const INDEXING: &str = include_str!("indexing.metal");
const UNARY: &str = include_str!("unary.metal");
@@ -18,138 +22,6 @@ const RANDOM: &str = include_str!("random.metal");
const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");
const QUANTIZED: &str = include_str!("quantized.metal");
-/// Most kernels apply similarly across the tensors
-/// This creates a strategy that uses the maximum amount of threads per threadgroup (capped at the
-/// actual total buffer length).
-/// Then kernels can just do their op on their single point in the buffer.
-fn linear_split(pipeline: &ComputePipelineState, length: usize) -> (MTLSize, MTLSize) {
- let size = length as u64;
- let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), size);
- let count = (size + width - 1) / width;
- let thread_group_count = MTLSize {
- width: count,
- height: 1,
- depth: 1,
- };
-
- let thread_group_size = MTLSize {
- width,
- height: 1,
- depth: 1,
- };
- (thread_group_count, thread_group_size)
-}
-
-// https://github.com/ml-explore/mlx/blob/bddf23f175726a57f0e443cd45518c0757daa166/mlx/backend/metal/utils.h#L96
-fn get_block_dims(dim0: u64, dim1: u64, dim2: u64) -> MTLSize {
- let mut pows0 = 0u64;
- let mut pows1 = 0u64;
- let mut pows2 = 0u64;
- let mut sum = 0u64;
- loop {
- let presum = sum;
- // Check all the pows
- if dim0 >= (1 << (pows0 + 1)) {
- pows0 += 1;
- sum += 1;
- }
- if sum == 10 {
- break;
- }
- if dim1 >= (1 << (pows1 + 1)) {
- pows1 += 1;
- sum += 1;
- }
- if sum == 10 {
- break;
- }
- if dim2 >= (1 << (pows2 + 1)) {
- pows2 += 1;
- sum += 1;
- }
- if sum == presum || sum == 10 {
- break;
- }
- }
- MTLSize {
- width: 1 << pows0,
- height: 1 << pows1,
- depth: 1 << pows2,
- }
-}
-
-fn set_param<P: EncoderParam>(encoder: &ComputeCommandEncoderRef, position: u64, data: P) {
- <P as EncoderParam>::set_param(encoder, position, data)
-}
-
-/// Helper functions to create the various objects on the compute command encoder
-/// on a single line.
-/// Prevents getting wrong some arguments number and mixing length and size in bytes.
-trait EncoderParam {
- fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self);
-}
-macro_rules! primitive {
- ($type:ty) => {
- impl EncoderParam for $type {
- fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
- encoder.set_bytes(
- position,
- core::mem::size_of::<$type>() as u64,
- &data as *const $type as *const c_void,
- );
- }
- }
- };
-}
-primitive!(bool);
-primitive!(usize);
-primitive!(i32);
-primitive!(i64);
-primitive!(u32);
-primitive!(u64);
-primitive!(f32);
-
-impl<T> EncoderParam for &[T] {
- fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
- encoder.set_bytes(
- position,
- core::mem::size_of_val(data) as u64,
- data.as_ptr() as *const c_void,
- );
- }
-}
-
-impl EncoderParam for &Buffer {
- fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
- encoder.set_buffer(position, Some(data), 0);
- }
-}
-impl EncoderParam for (&Buffer, usize) {
- fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
- encoder.set_buffer(position, Some(data.0), data.1 as u64);
- }
-}
-impl EncoderParam for &mut Buffer {
- fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
- encoder.set_buffer(position, Some(data), 0);
- }
-}
-impl EncoderParam for (&mut Buffer, usize) {
- fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
- encoder.set_buffer(position, Some(data.0), data.1 as u64);
- }
-}
-
-macro_rules! set_params {
- ($encoder:ident, ($($param:expr),+)) => (
- let mut _index = 0;
- $(
- set_param($encoder, _index, $param);
- _index += 1;
- )*
- );
-}
-
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Source {
Affine,
@@ -273,6 +145,12 @@ pub struct Kernels {
pipelines: RwLock<Pipelines>,
}
+impl Default for Kernels {
+ fn default() -> Self {
+ Self::new()
+ }
+}
+
impl Kernels {
pub fn new() -> Self {
let libraries = RwLock::new(Libraries::new());
@@ -396,17 +274,17 @@ pub fn call_unary_contiguous(
kernels: &Kernels,
kernel_name: unary::contiguous::Kernel,
length: usize,
- input: &Buffer,
+ input: BufferOffset,
output: &Buffer,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);
- set_params!(encoder, (length, input, output));
+ set_params!(encoder, (length, &input, output));
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
- encoder.use_resource(input, metal::MTLResourceUsage::Read);
+ encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
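
The contiguous path migrates the same way, with BufferOffset now taken by value. A hedged caller sketch (the kernel handle name follows the crate's ops macros and is an assumption; the surrounding bindings are assumed to be in scope):

    let src = BufferOffset::zero_offset(&input);
    call_unary_contiguous(
        &device, command_buffer, &kernels,
        unary::contiguous::cos::FLOAT, // assumed kernel handle
        length, src, &output,
    )?;
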
@@ -463,11 +341,9 @@ pub fn call_unary_strided(
kernels: &Kernels,
name: unary::strided::Kernel,
shape: &[usize],
- input: &Buffer,
+ input: BufferOffset,
strides: &[usize],
- offset: usize,
- output: &Buffer,
- output_offset: usize,
+ output: BufferOffset,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Unary, name.0)?;
@@ -476,23 +352,13 @@ pub fn call_unary_strided(
encoder.set_compute_pipeline_state(&pipeline);
let length: usize = shape.iter().product();
- set_params!(
- encoder,
- (
- length,
- num_dims,
- shape,
- strides,
- (input, offset),
- (output, output_offset)
- )
- );
+ set_params!(encoder, (length, num_dims, shape, strides, &input, &output));
let width: usize = shape.iter().product();
let (thread_group_count, thread_group_size) = linear_split(&pipeline, width);
- encoder.use_resource(input, metal::MTLResourceUsage::Read);
- encoder.use_resource(output, metal::MTLResourceUsage::Write);
+ encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
+ encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
Ok(())
@@ -505,8 +371,8 @@ pub fn call_binary_contiguous(
kernels: &Kernels,
kernel_name: binary::contiguous::Kernel,
length: usize,
- left: &Buffer,
- right: &Buffer,
+ left: BufferOffset,
+ right: BufferOffset,
output: &Buffer,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Binary, kernel_name.0)?;
@@ -514,12 +380,12 @@ pub fn call_binary_contiguous(
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);
- set_params!(encoder, (length, left, right, output));
+ set_params!(encoder, (length, &left, &right, output));
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
- encoder.use_resource(left, metal::MTLResourceUsage::Read);
- encoder.use_resource(right, metal::MTLResourceUsage::Read);
+ encoder.use_resource(left.buffer, metal::MTLResourceUsage::Read);
+ encoder.use_resource(right.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
@@ -533,12 +399,10 @@ pub fn call_binary_strided(
kernels: &Kernels,
name: binary::strided::Kernel,
shape: &[usize],
- left_input: &Buffer,
+ left_input: BufferOffset,
left_strides: &[usize],
- left_offset: usize,
- right_input: &Buffer,
+ right_input: BufferOffset,
right_strides: &[usize],
- right_offset: usize,
output: &Buffer,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Binary, name.0)?;
@@ -558,16 +422,16 @@ pub fn call_binary_strided(
shape,
left_strides,
right_strides,
- (left_input, left_offset),
- (right_input, right_offset),
+ &left_input,
+ &right_input,
output
)
);
let (thread_group_count, thread_group_size) = linear_split(&pipeline, width);
- encoder.use_resource(left_input, metal::MTLResourceUsage::Read);
- encoder.use_resource(right_input, metal::MTLResourceUsage::Read);
+ encoder.use_resource(left_input.buffer, metal::MTLResourceUsage::Read);
+ encoder.use_resource(right_input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
@@ -581,8 +445,7 @@ pub fn call_cast_contiguous(
kernels: &Kernels,
kernel_name: &'static str,
length: usize,
- input: &Buffer,
- input_offset: usize,
+ input: BufferOffset,
output: &Buffer,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?;
@@ -590,10 +453,10 @@ pub fn call_cast_contiguous(
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);
- set_params!(encoder, (length, (input, input_offset), output));
+ set_params!(encoder, (length, &input, output));
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
- encoder.use_resource(input, metal::MTLResourceUsage::Read);
+ encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
@@ -607,9 +470,8 @@ pub fn call_cast_strided(
kernels: &Kernels,
kernel_name: &'static str,
shape: &[usize],
- input: &Buffer,
+ input: BufferOffset,
input_strides: &[usize],
- input_offset: usize,
output: &Buffer,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?;
@@ -621,25 +483,19 @@ pub fn call_cast_strided(
set_params!(
encoder,
- (
- length,
- shape.len(),
- shape,
- input_strides,
- (input, input_offset),
- output
- )
+ (length, shape.len(), shape, input_strides, &input, output)
);
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
- encoder.use_resource(input, metal::MTLResourceUsage::Read);
+ encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
Ok(())
}
+#[allow(clippy::too_many_arguments)]
pub fn call_reduce_contiguous(
device: &Device,
command_buffer: &CommandBufferRef,
@@ -687,6 +543,7 @@ pub fn call_reduce_contiguous(
Ok(())
}
+#[allow(clippy::too_many_arguments)]
pub fn call_reduce_strided(
device: &Device,
command_buffer: &CommandBufferRef,
@@ -985,7 +842,7 @@ pub fn call_affine(
kernels: &Kernels,
name: &'static str,
size: usize,
- input: &Buffer,
+ input: BufferOffset,
output: &Buffer,
mul: f32,
add: f32,
@@ -995,10 +852,10 @@ pub fn call_affine(
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);
- set_params!(encoder, (size, mul, add, input, output));
+ set_params!(encoder, (size, mul, add, &input, output));
let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
- encoder.use_resource(input, metal::MTLResourceUsage::Read);
+ encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
@@ -1012,9 +869,8 @@ pub fn call_affine_strided(
kernels: &Kernels,
name: &'static str,
shape: &[usize],
- input: &Buffer,
+ input: BufferOffset,
input_stride: &[usize],
- input_offset: usize,
output: &Buffer,
mul: f32,
add: f32,
@@ -1034,13 +890,13 @@ pub fn call_affine_strided(
input_stride,
mul,
add,
- (input, input_offset),
+ &input,
output
)
);
let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
- encoder.use_resource(input, metal::MTLResourceUsage::Read);
+ encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
@@ -1054,7 +910,7 @@ pub fn call_powf(
kernels: &Kernels,
name: &'static str,
size: usize,
- input: &Buffer,
+ input: BufferOffset,
output: &Buffer,
mul: f32,
) -> Result<(), MetalKernelError> {
@@ -1063,10 +919,10 @@ pub fn call_powf(
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);
- set_params!(encoder, (size, mul, input, output));
+ set_params!(encoder, (size, mul, &input, output));
let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
- encoder.use_resource(input, metal::MTLResourceUsage::Read);
+ encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
@@ -1080,9 +936,8 @@ pub fn call_powf_strided(
kernels: &Kernels,
name: &'static str,
shape: &[usize],
- input: &Buffer,
+ input: BufferOffset,
input_stride: &[usize],
- input_offset: usize,
output: &Buffer,
mul: f32,
) -> Result<(), MetalKernelError> {
@@ -1094,19 +949,11 @@ pub fn call_powf_strided(
set_params!(
encoder,
- (
- size,
- shape.len(),
- shape,
- input_stride,
- mul,
- (input, input_offset),
- output
- )
+ (size, shape.len(), shape, input_stride, mul, &input, output)
);
let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
- encoder.use_resource(input, metal::MTLResourceUsage::Read);
+ encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
@@ -1120,7 +967,7 @@ pub fn call_elu(
kernels: &Kernels,
name: &'static str,
size: usize,
- input: &Buffer,
+ input: BufferOffset,
output: &Buffer,
mul: f32,
) -> Result<(), MetalKernelError> {
@@ -1129,10 +976,10 @@ pub fn call_elu(
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);
- set_params!(encoder, (size, mul, input, output));
+ set_params!(encoder, (size, mul, &input, output));
let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
- encoder.use_resource(input, metal::MTLResourceUsage::Read);
+ encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
@@ -1146,9 +993,8 @@ pub fn call_elu_strided(
kernels: &Kernels,
name: &'static str,
shape: &[usize],
- input: &Buffer,
+ input: BufferOffset,
input_stride: &[usize],
- input_offset: usize,
output: &Buffer,
mul: f32,
) -> Result<(), MetalKernelError> {
@@ -1160,25 +1006,18 @@ pub fn call_elu_strided(
set_params!(
encoder,
- (
- size,
- shape.len(),
- shape,
- input_stride,
- mul,
- (input, input_offset),
- output
- )
+ (size, shape.len(), shape, input_stride, mul, &input, output)
);
let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
- encoder.use_resource(input, metal::MTLResourceUsage::Read);
+ encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
Ok(())
}
+#[allow(clippy::too_many_arguments)]
pub fn call_where_cond_strided(
device: &Device,
command_buffer: &CommandBufferRef,
@@ -1334,6 +1173,7 @@ pub fn call_gather(
Ok(())
}
+#[allow(clippy::too_many_arguments)]
pub fn call_scatter_add(
device: &Device,
command_buffer: &CommandBufferRef,
@@ -1384,6 +1224,7 @@ pub fn call_scatter_add(
Ok(())
}
+#[allow(clippy::too_many_arguments)]
pub fn call_index_add(
device: &Device,
command_buffer: &CommandBufferRef,
@@ -1910,6 +1751,7 @@ pub enum GgmlDType {
F32,
}
+#[allow(clippy::too_many_arguments)]
pub fn call_quantized_matmul_t(
device: &Device,
command_buffer: &CommandBufferRef,
@@ -1925,16 +1767,16 @@ pub fn call_quantized_matmul_t(
let ne00 = k as i64;
let ne01 = n as i64;
let ne02 = b as i64;
- let ne03 = 1 as i64;
+ let ne03 = 1i64;
let nb00 = 0i64;
- let nb01 = 0 as i64;
- let nb02 = 0 as i64;
+ let nb01 = 0i64;
+ let nb02 = 0i64;
let ne10 = k as i64;
let ne11 = m as i64;
let ne12 = b as i64;
- let ne13 = 1 as i64;
+ let ne13 = 1i64;
let nb10 = 0i64;
let nb11 = 0i64;
@@ -2169,6 +2011,7 @@ pub struct CallConvTranspose2dCfg<'a> {
pub kernel_offset: usize,
}
+#[allow(clippy::too_many_arguments)]
pub fn call_conv_transpose2d(
device: &Device,
command_buffer: &CommandBufferRef,
diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs
index b15d9b36..b91c92d8 100644
--- a/candle-metal-kernels/src/tests.rs
+++ b/candle-metal-kernels/src/tests.rs
@@ -12,7 +12,7 @@ fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> {
fn new_buffer<T>(device: &Device, data: &[T]) -> Buffer {
let options = MTLResourceOptions::StorageModeManaged;
let ptr = data.as_ptr() as *const c_void;
- let size = (data.len() * std::mem::size_of::<T>()) as u64;
+ let size = std::mem::size_of_val(data) as u64;
device.new_buffer_with_data(ptr, size, options)
}
@@ -41,6 +41,10 @@ fn run<T: Clone>(v: &[T], name: unary::contiguous::Kernel) -> Vec<T> {
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let input = new_buffer(&device, v);
+ let input = BufferOffset {
+ buffer: &input,
+ offset_in_bytes: 0,
+ };
let output = new_buffer(&device, v);
call_unary_contiguous(
&device,
@@ -48,7 +52,7 @@ fn run<T: Clone>(v: &[T], name: unary::contiguous::Kernel) -> Vec<T> {
&kernels,
name,
v.len(),
- &input,
+ input,
&output,
)
.unwrap();
@@ -72,8 +76,8 @@ fn run_binary<T: Clone>(x: &[T], y: &[T], name: binary::contiguous::Kernel) -> V
&kernels,
name,
x.len(),
- &left,
- &right,
+ BufferOffset::zero_offset(&left),
+ BufferOffset::zero_offset(&right),
&output,
)
.unwrap();
@@ -93,7 +97,15 @@ fn run_strided<T: Clone>(
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let input = new_buffer(&device, v);
- let output = new_buffer(&device, v);
+ let input = BufferOffset {
+ buffer: &input,
+ offset_in_bytes: offset,
+ };
+ let output_b = new_buffer(&device, v);
+ let output = BufferOffset {
+ buffer: &output_b,
+ offset_in_bytes: 0,
+ };
let kernels = Kernels::new();
call_unary_strided(
&device,
@@ -101,16 +113,14 @@ fn run_strided<T: Clone>(
&kernels,
kernel,
shape,
- &input,
+ input,
strides,
- offset,
- &output,
- 0,
+ output,
)
.unwrap();
command_buffer.commit();
command_buffer.wait_until_completed();
- read_to_vec(&output, v.len())
+ read_to_vec(&output_b, v.len())
}
#[test]
@@ -308,8 +318,7 @@ fn run_cast<T: Clone, U: Clone>(v: &[T], name: &'static str) -> Vec<U> {
&kernels,
name,
v.len(),
- &input,
- 0,
+ BufferOffset::zero_offset(&input),
&output,
)
.unwrap();
@@ -521,7 +530,7 @@ fn run_affine<T: Clone>(v: &[T], mul: f64, add: f64) -> Vec<T> {
&kernels,
"affine_f32",
size,
- &input,
+ BufferOffset::zero_offset(&input),
&output,
mul as f32,
add as f32,
@@ -554,9 +563,8 @@ fn run_affine_strided<T: Clone>(
&kernels,
"affine_f32_strided",
shape,
- &input,
+ BufferOffset::zero_offset(&input),
strides,
- 0,
&output,
mul as f32,
add as f32,
@@ -633,7 +641,7 @@ fn index_select_strided() {
fn index_select_f16() {
let embedding: Vec<_> = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]
.into_iter()
- .map(|x| f16::from_f32(x))
+ .map(f16::from_f32)
.collect();
let shape = [5, 2];
let stride = [2, 1];
@@ -700,8 +708,8 @@ fn run_index_select<T: Clone, I: Clone + std::fmt::Debug>(
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
- let embeddings_buffer = new_buffer(&device, &embeddings);
- let ids_buffer = new_buffer(&device, &ids);
+ let embeddings_buffer = new_buffer(&device, embeddings);
+ let ids_buffer = new_buffer(&device, ids);
let left_size: usize = shape[..dim].iter().product();
let right_size: usize = shape[dim + 1..].iter().product();
@@ -711,7 +719,7 @@ fn run_index_select<T: Clone, I: Clone + std::fmt::Debug>(
let kernels = Kernels::new();
call_index_select(
&device,
- &command_buffer,
+ command_buffer,
&kernels,
name,
shape,
@@ -746,8 +754,8 @@ fn run_index_select_strided<T: Clone, I: Clone + std::fmt::Debug>(
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
- let embeddings_buffer = new_buffer(&device, &embeddings);
- let ids_buffer = new_buffer(&device, &ids);
+ let embeddings_buffer = new_buffer(&device, embeddings);
+ let ids_buffer = new_buffer(&device, ids);
let left_size: usize = shape[..dim].iter().product();
let right_size: usize = shape[dim + 1..].iter().product();
@@ -757,7 +765,7 @@ fn run_index_select_strided<T: Clone, I: Clone + std::fmt::Debug>(
let kernels = Kernels::new();
call_index_select(
&device,
- &command_buffer,
+ command_buffer,
&kernels,
name,
shape,
@@ -931,6 +939,7 @@ fn softmax() {
);
}
+#[allow(clippy::too_many_arguments)]
fn run_where_cond<I: Clone, T: Clone>(
shape: &[usize],
cond: &[I],
@@ -1148,7 +1157,7 @@ fn run_random<T: Clone>(name: &'static str, seed: u32, length: usize, a: f32, b:
#[test]
fn random() {
fn calc_mean(data: &[f32]) -> f32 {
- let sum = data.iter().sum::<f32>() as f32;
+ let sum = data.iter().sum::<f32>();
let count = data.len();
assert!(count > 0);
sum / count as f32
@@ -1162,7 +1171,7 @@ fn random() {
let variance = data
.iter()
.map(|value| {
- let diff = mean - (*value as f32);
+ let diff = mean - *value;
diff * diff
})
.sum::<f32>()
@@ -1787,6 +1796,7 @@ fn avg_pool2d_u32() {
assert_eq!(results, expected);
}
+#[allow(clippy::too_many_arguments)]
fn run_conv_transpose1d<T: Clone>(
input: &[T],
input_shape: &[usize],
diff --git a/candle-metal-kernels/src/utils.rs b/candle-metal-kernels/src/utils.rs
new file mode 100644
index 00000000..194cddf4
--- /dev/null
+++ b/candle-metal-kernels/src/utils.rs
@@ -0,0 +1,162 @@
+use metal::{Buffer, ComputeCommandEncoderRef, ComputePipelineState, MTLSize};
+use std::ffi::c_void;
+
+/// Most kernels apply similarly across the tensors
+/// This creates a strategy that uses the maximum amount of threads per threadgroup (capped at the
+/// actual total buffer length).
+/// Then kernels can just do their op on their single point in the buffer.
+pub(crate) fn linear_split(pipeline: &ComputePipelineState, length: usize) -> (MTLSize, MTLSize) {
+ let size = length as u64;
+ let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), size);
+ let count = (size + width - 1) / width;
+ let thread_group_count = MTLSize {
+ width: count,
+ height: 1,
+ depth: 1,
+ };
+
+ let thread_group_size = MTLSize {
+ width,
+ height: 1,
+ depth: 1,
+ };
+ (thread_group_count, thread_group_size)
+}
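
To make the split concrete, here is a standalone sketch of the same arithmetic (no Metal required; 1024 stands in for max_total_threads_per_threadgroup, which the real pipeline reports at runtime; out-of-range threads are expected to bounds-check inside the kernel):

    fn linear_split_sketch(max_threads: u64, length: u64) -> (u64, u64) {
        // One thread per element, capped at the hardware threadgroup limit.
        let width = max_threads.min(length);
        // Ceiling division so the last, possibly partial, group is still launched.
        let count = (length + width - 1) / width;
        (count, width)
    }

    fn main() {
        assert_eq!(linear_split_sketch(1024, 4096), (4, 1024));
        assert_eq!(linear_split_sketch(1024, 10), (1, 10));     // tiny launch: one 10-thread group
        assert_eq!(linear_split_sketch(1024, 1025), (2, 1024)); // second group mostly idle
    }
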
+
+// https://github.com/ml-explore/mlx/blob/bddf23f175726a57f0e443cd45518c0757daa166/mlx/backend/metal/utils.h#L96
+pub(crate) fn get_block_dims(dim0: u64, dim1: u64, dim2: u64) -> MTLSize {
+ let mut pows0 = 0u64;
+ let mut pows1 = 0u64;
+ let mut pows2 = 0u64;
+ let mut sum = 0u64;
+ loop {
+ let presum = sum;
+ // Check all the pows
+ if dim0 >= (1 << (pows0 + 1)) {
+ pows0 += 1;
+ sum += 1;
+ }
+ if sum == 10 {
+ break;
+ }
+ if dim1 >= (1 << (pows1 + 1)) {
+ pows1 += 1;
+ sum += 1;
+ }
+ if sum == 10 {
+ break;
+ }
+ if dim2 >= (1 << (pows2 + 1)) {
+ pows2 += 1;
+ sum += 1;
+ }
+ if sum == presum || sum == 10 {
+ break;
+ }
+ }
+ MTLSize {
+ width: 1 << pows0,
+ height: 1 << pows1,
+ depth: 1 << pows2,
+ }
+}
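
The loop above hands out at most ten doublings (2^10 = 1024 threads per group) round-robin across the three dimensions, stopping early once no dimension can absorb another doubling. A hand-traced example of the arithmetic:

    // get_block_dims(100, 8, 2):
    //   pows0 climbs to 6  (2^6 = 64 <= 100; the budget of 10 runs out here)
    //   pows1 climbs to 3  (2^3 = 8; 2^4 = 16 > 8, so dim1 is exhausted)
    //   pows2 climbs to 1  (2^1 = 2; dim2 exhausted)
    //   the loop breaks when sum reaches 10, i.e. 64 * 8 * 2 = 1024 threads
    //   => MTLSize { width: 64, height: 8, depth: 2 }
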
+
+pub(crate) fn set_param<P: EncoderParam>(
+ encoder: &ComputeCommandEncoderRef,
+ position: u64,
+ data: P,
+) {
+ <P as EncoderParam>::set_param(encoder, position, data)
+}
+
+/// Helper functions to create the various objects on the compute command encoder
+/// on a single line.
+/// Prevents getting wrong some arguments number and mixing length and size in bytes.
+pub(crate) trait EncoderParam {
+ fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self);
+}
+macro_rules! primitive {
+ ($type:ty) => {
+ impl EncoderParam for $type {
+ fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
+ encoder.set_bytes(
+ position,
+ core::mem::size_of::<$type>() as u64,
+ &data as *const $type as *const c_void,
+ );
+ }
+ }
+ };
+}
+primitive!(bool);
+primitive!(usize);
+primitive!(i32);
+primitive!(i64);
+primitive!(u32);
+primitive!(u64);
+primitive!(f32);
+
+pub struct BufferOffset<'a> {
+ pub buffer: &'a Buffer,
+ pub offset_in_bytes: usize,
+}
+
+impl<'a> BufferOffset<'a> {
+ pub fn zero_offset(buffer: &'a Buffer) -> Self {
+ Self {
+ buffer,
+ offset_in_bytes: 0,
+ }
+ }
+}
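
A minimal usage sketch (assuming buf is a metal::Buffer holding f32 values; note that the offset is in bytes, not elements):

    // View the whole buffer from the start.
    let whole = BufferOffset::zero_offset(&buf);
    // Skip the first four f32 elements.
    let tail = BufferOffset {
        buffer: &buf,
        offset_in_bytes: 4 * std::mem::size_of::<f32>(),
    };
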
+
+impl<T> EncoderParam for &[T] {
+ fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
+ encoder.set_bytes(
+ position,
+ core::mem::size_of_val(data) as u64,
+ data.as_ptr() as *const c_void,
+ );
+ }
+}
+
+impl EncoderParam for &Buffer {
+ fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
+ encoder.set_buffer(position, Some(data), 0);
+ }
+}
+
+impl EncoderParam for (&Buffer, usize) {
+ fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
+ encoder.set_buffer(position, Some(data.0), data.1 as u64);
+ }
+}
+
+impl<'a> EncoderParam for &BufferOffset<'a> {
+ fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
+ encoder.set_buffer(position, Some(data.buffer), data.offset_in_bytes as u64);
+ }
+}
+
+impl EncoderParam for &mut Buffer {
+ fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
+ encoder.set_buffer(position, Some(data), 0);
+ }
+}
+
+impl EncoderParam for (&mut Buffer, usize) {
+ fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
+ encoder.set_buffer(position, Some(data.0), data.1 as u64);
+ }
+}
+
+#[macro_export]
+macro_rules! set_params {
+ ($encoder:ident, ($($param:expr),+)) => (
+ let mut _index = 0;
+ $(
+ $crate::utils::set_param($encoder, _index, $param);
+ _index += 1;
+ )*
+ );
+}
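
For reference, a call such as the one in call_unary_contiguous expands to sequential set_param invocations, one per argument position; the positions must line up with the argument indices the Metal kernel declares. A rough expansion, for illustration only:

    set_params!(encoder, (length, &input, output));
    // roughly expands to:
    //   set_param(encoder, 0, length);  // scalar: encoder.set_bytes(0, ...)
    //   set_param(encoder, 1, &input);  // encoder.set_buffer(1, Some(input.buffer), input.offset_in_bytes)
    //   set_param(encoder, 2, output);  // encoder.set_buffer(2, Some(output), 0)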