summary | refs | log | tree | commit | diff
path: root/candle-metal-kernels
diff options
context:
space:
mode:
author: Laurent Mazare <laurent.mazare@gmail.com> 2024-04-08 09:37:25 +0200
committer: GitHub <noreply@github.com> 2024-04-08 09:37:25 +0200
commit: 718671a0d5b751458033fb6425fb518ca4dc3b5f (patch)
tree: e72ccab88dbb1ee28878664ba2574bdb5563eca1 /candle-metal-kernels
parent: c5fe4a7f8983ae7c9641fa923f26ef60538aef06 (diff)
download: candle-718671a0d5b751458033fb6425fb518ca4dc3b5f.tar.gz
candle-718671a0d5b751458033fb6425fb518ca4dc3b5f.tar.bz2
candle-718671a0d5b751458033fb6425fb518ca4dc3b5f.zip
Use BufferOffset in metal backend ops. (#2029)
* Use BufferOffset in the metal backend.
* More BufferOffset usage.
* Use in where-cond.
Diffstat (limited to 'candle-metal-kernels')
-rw-r--r-- candle-metal-kernels/src/lib.rs 155
-rw-r--r-- candle-metal-kernels/src/tests.rs 51
2 files changed, 78 insertions, 128 deletions
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index 23c072af..78108127 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -503,8 +503,7 @@ pub fn call_reduce_contiguous(
kernel_name: &'static str,
length: usize,
out_length: usize,
- input: &Buffer,
- input_offset: usize,
+ input: BufferOffset,
output: &Buffer,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
@@ -513,10 +512,7 @@ pub fn call_reduce_contiguous(
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);
- set_params!(
- encoder,
- (length, elements_to_sum, (input, input_offset), output)
- );
+ set_params!(encoder, (length, elements_to_sum, &input, output));
let thread_group_count = MTLSize {
width: out_length as u64,
@@ -536,7 +532,7 @@ pub fn call_reduce_contiguous(
depth: 1,
};
- encoder.use_resource(input, metal::MTLResourceUsage::Read);
+ encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
@@ -552,8 +548,7 @@ pub fn call_reduce_strided(
shape: &[usize],
strides: &[usize],
out_length: usize,
- input: &Buffer,
- input_offset: usize,
+ input: BufferOffset,
output: &Buffer,
) -> Result<(), MetalKernelError> {
let length: usize = shape.iter().product();
@@ -565,14 +560,7 @@ pub fn call_reduce_strided(
set_params!(
encoder,
- (
- shape.len(),
- shape,
- strides,
- elements_to_sum,
- (input, input_offset),
- output
- )
+ (shape.len(), shape, strides, elements_to_sum, &input, output)
);
let thread_group_count = MTLSize {
@@ -593,7 +581,7 @@ pub fn call_reduce_strided(
depth: 1,
};
- encoder.use_resource(input, metal::MTLResourceUsage::Read);
+ encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
@@ -1024,12 +1012,12 @@ pub fn call_where_cond_strided(
kernels: &Kernels,
name: &'static str,
shape: &[usize],
- cond: &Buffer,
- (cond_stride, cond_offset): (&[usize], usize),
- left: &Buffer,
- (left_stride, left_offset): (&[usize], usize),
- right: &Buffer,
- (right_stride, right_offset): (&[usize], usize),
+ cond: BufferOffset,
+ cond_stride: &[usize],
+ left: BufferOffset,
+ left_stride: &[usize],
+ right: BufferOffset,
+ right_stride: &[usize],
output: &Buffer,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Ternary, name)?;
@@ -1049,18 +1037,18 @@ pub fn call_where_cond_strided(
cond_stride,
left_stride,
right_stride,
- (cond, cond_offset),
- (left, left_offset),
- (right, right_offset),
+ &cond,
+ &left,
+ &right,
output
)
);
let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
- encoder.use_resource(cond, metal::MTLResourceUsage::Read);
- encoder.use_resource(left, metal::MTLResourceUsage::Read);
- encoder.use_resource(right, metal::MTLResourceUsage::Read);
+ encoder.use_resource(cond.buffer, metal::MTLResourceUsage::Read);
+ encoder.use_resource(left.buffer, metal::MTLResourceUsage::Read);
+ encoder.use_resource(right.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
@@ -1079,10 +1067,8 @@ pub fn call_index_select(
contiguous: bool,
src_dims: &[usize],
src_strides: &[usize],
- input: &Buffer,
- src_offset: usize,
- ids: &Buffer,
- ids_offset: usize,
+ input: BufferOffset,
+ ids: BufferOffset,
output: &Buffer,
) -> Result<(), MetalKernelError> {
let left_size: usize = shape[..dim].iter().product();
@@ -1107,16 +1093,16 @@ pub fn call_index_select(
contiguous,
src_dims,
src_strides,
- (input, src_offset),
- (ids, ids_offset),
+ &input,
+ &ids,
output
)
);
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
- encoder.use_resource(input, metal::MTLResourceUsage::Read);
- encoder.use_resource(ids, metal::MTLResourceUsage::Read);
+ encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
+ encoder.use_resource(ids.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
@@ -1132,10 +1118,8 @@ pub fn call_gather(
shape: &[usize],
ids_size: usize,
dim: usize,
- input: &Buffer,
- input_offset: usize,
- ids: &Buffer,
- ids_offset: usize,
+ input: BufferOffset,
+ ids: BufferOffset,
output: &Buffer,
) -> Result<(), MetalKernelError> {
let left_size: usize = shape[..dim].iter().product();
@@ -1157,16 +1141,16 @@ pub fn call_gather(
src_dim_size,
right_size,
ids_size,
- (input, input_offset),
- (ids, ids_offset),
+ &input,
+ &ids,
output
)
);
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
- encoder.use_resource(input, metal::MTLResourceUsage::Read);
- encoder.use_resource(ids, metal::MTLResourceUsage::Read);
+ encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
+ encoder.use_resource(ids.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
@@ -1182,10 +1166,8 @@ pub fn call_scatter_add(
src_shape: &[usize],
dst_shape: &[usize],
dim: usize,
- input: &Buffer,
- input_offset: usize,
- ids: &Buffer,
- ids_offset: usize,
+ input: BufferOffset,
+ ids: BufferOffset,
output: &Buffer,
) -> Result<(), MetalKernelError> {
let left_size: usize = src_shape[..dim].iter().product();
@@ -1208,16 +1190,16 @@ pub fn call_scatter_add(
src_dim_size,
right_size,
dst_dim_size,
- (input, input_offset),
- (ids, ids_offset),
+ &input,
+ &ids,
output
)
);
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
- encoder.use_resource(input, metal::MTLResourceUsage::Read);
- encoder.use_resource(ids, metal::MTLResourceUsage::Read);
+ encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
+ encoder.use_resource(ids.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
@@ -1234,10 +1216,8 @@ pub fn call_index_add(
dst_shape: &[usize],
ids_shape: &[usize],
dim: usize,
- input: &Buffer,
- input_offset: usize,
- ids: &Buffer,
- ids_offset: usize,
+ input: BufferOffset,
+ ids: BufferOffset,
output: &Buffer,
) -> Result<(), MetalKernelError> {
let left_size: usize = src_shape[..dim].iter().product();
@@ -1261,16 +1241,16 @@ pub fn call_index_add(
right_size,
dst_dim_size,
ids_dim_size,
- (input, input_offset),
- (ids, ids_offset),
+ &input,
+ &ids,
output
)
);
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
- encoder.use_resource(input, metal::MTLResourceUsage::Read);
- encoder.use_resource(ids, metal::MTLResourceUsage::Read);
+ encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
+ encoder.use_resource(ids.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
@@ -1536,8 +1516,7 @@ pub fn call_im2col1d_strided(
shape: &[usize],
strides: &[usize],
(k_size, stride, padding, dilation): (usize, usize, usize, usize),
- input: &Buffer,
- input_offset: usize,
+ input: BufferOffset,
output: &Buffer,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Conv, name)?;
@@ -1549,20 +1528,9 @@ pub fn call_im2col1d_strided(
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
encoder,
- (
- dst_el,
- l_out,
- k_size,
- stride,
- padding,
- dilation,
- shape,
- strides,
- (input, input_offset),
- output
- )
+ (dst_el, l_out, k_size, stride, padding, dilation, shape, strides, &input, output)
);
- encoder.use_resource(input, metal::MTLResourceUsage::Read);
+ encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
@@ -1579,8 +1547,7 @@ pub fn call_im2col_strided(
shape: &[usize],
strides: &[usize],
(h_k, w_k, stride, padding, dilation): (usize, usize, usize, usize, usize),
- input: &Buffer,
- input_offset: usize,
+ input: BufferOffset,
output: &Buffer,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Conv, name)?;
@@ -1598,21 +1565,11 @@ pub fn call_im2col_strided(
set_params!(
encoder,
(
- dst_el,
- h_out,
- w_out,
- h_k,
- w_k,
- stride,
- padding,
- dilation,
- shape,
- strides,
- (input, input_offset),
+ dst_el, h_out, w_out, h_k, w_k, stride, padding, dilation, shape, strides, &input,
output
)
);
- encoder.use_resource(input, metal::MTLResourceUsage::Read);
+ encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
@@ -1630,8 +1587,7 @@ pub fn call_upsample_nearest_2d(
strides: &[usize],
out_w: usize,
out_h: usize,
- input: &Buffer,
- input_offset: usize,
+ input: BufferOffset,
output: &Buffer,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Conv, name)?;
@@ -1643,18 +1599,9 @@ pub fn call_upsample_nearest_2d(
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
encoder,
- (
- out_w,
- out_h,
- scale_w,
- scale_h,
- shape,
- strides,
- (input, input_offset),
- output
- )
+ (out_w, out_h, scale_w, scale_h, shape, strides, &input, output)
);
- encoder.use_resource(input, metal::MTLResourceUsage::Read);
+ encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs
index b91c92d8..960ae1df 100644
--- a/candle-metal-kernels/src/tests.rs
+++ b/candle-metal-kernels/src/tests.rs
@@ -728,10 +728,8 @@ fn run_index_select<T: Clone, I: Clone + std::fmt::Debug>(
true,
shape,
stride,
- &embeddings_buffer,
- 0,
- &ids_buffer,
- 0,
+ BufferOffset::zero_offset(&embeddings_buffer),
+ BufferOffset::zero_offset(&ids_buffer),
&dst_buffer,
)
.unwrap();
@@ -774,10 +772,8 @@ fn run_index_select_strided<T: Clone, I: Clone + std::fmt::Debug>(
false,
shape,
stride,
- &embeddings_buffer,
- 0,
- &ids_buffer,
- 0,
+ BufferOffset::zero_offset(&embeddings_buffer),
+ BufferOffset::zero_offset(&ids_buffer),
&dst_buffer,
)
.unwrap();
@@ -819,8 +815,7 @@ fn run_reduce<T: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<T
&dims,
&strides,
out_length,
- &input,
- 0,
+ BufferOffset::zero_offset(&input),
&output,
)
.unwrap();
@@ -974,18 +969,30 @@ fn run_where_cond<I: Clone, T: Clone>(
);
let output = device.new_buffer((length * core::mem::size_of::<T>()) as u64, options);
+ let cond = BufferOffset {
+ buffer: &cond,
+ offset_in_bytes: cond_offset,
+ };
+ let left = BufferOffset {
+ buffer: &left,
+ offset_in_bytes: left_offset,
+ };
+ let right = BufferOffset {
+ buffer: &right,
+ offset_in_bytes: cond_offset,
+ };
call_where_cond_strided(
&device,
command_buffer,
&kernels,
name,
shape,
- &cond,
- (&cond_stride, cond_offset),
- &left,
- (&left_stride, left_offset),
- &right,
- (&cond_stride, cond_offset),
+ cond,
+ &cond_stride,
+ left,
+ &left_stride,
+ right,
+ &cond_stride,
&output,
)
.unwrap();
@@ -1250,10 +1257,8 @@ fn run_scatter_add<T: Clone, I: Clone + std::fmt::Debug>(
shape,
shape,
dim,
- &input_buffer,
- 0,
- &ids_buffer,
- 0,
+ BufferOffset::zero_offset(&input_buffer),
+ BufferOffset::zero_offset(&ids_buffer),
&output,
)
.unwrap();
@@ -1355,10 +1360,8 @@ fn run_index_add<T: Clone, I: Clone + std::fmt::Debug>(
shape,
shape,
dim,
- &input_buffer,
- 0,
- &indices_buffer,
- 0,
+ BufferOffset::zero_offset(&input_buffer),
+ BufferOffset::zero_offset(&indices_buffer),
&output,
)
.unwrap();