diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-04-08 09:37:25 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-04-08 09:37:25 +0200 |
commit | 718671a0d5b751458033fb6425fb518ca4dc3b5f (patch) | |
tree | e72ccab88dbb1ee28878664ba2574bdb5563eca1 /candle-metal-kernels | |
parent | c5fe4a7f8983ae7c9641fa923f26ef60538aef06 (diff) | |
download | candle-718671a0d5b751458033fb6425fb518ca4dc3b5f.tar.gz candle-718671a0d5b751458033fb6425fb518ca4dc3b5f.tar.bz2 candle-718671a0d5b751458033fb6425fb518ca4dc3b5f.zip |
Use BufferOffset in metal backend ops. (#2029)
* Use BufferOffset in the metal backend.
* More BufferOffset usage.
* Use in where-cond.
Diffstat (limited to 'candle-metal-kernels')
-rw-r--r-- | candle-metal-kernels/src/lib.rs | 155 | ||||
-rw-r--r-- | candle-metal-kernels/src/tests.rs | 51 |
2 files changed, 78 insertions, 128 deletions
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index 23c072af..78108127 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -503,8 +503,7 @@ pub fn call_reduce_contiguous( kernel_name: &'static str, length: usize, out_length: usize, - input: &Buffer, - input_offset: usize, + input: BufferOffset, output: &Buffer, ) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?; @@ -513,10 +512,7 @@ pub fn call_reduce_contiguous( let encoder = command_buffer.new_compute_command_encoder(); encoder.set_compute_pipeline_state(&pipeline); - set_params!( - encoder, - (length, elements_to_sum, (input, input_offset), output) - ); + set_params!(encoder, (length, elements_to_sum, &input, output)); let thread_group_count = MTLSize { width: out_length as u64, @@ -536,7 +532,7 @@ pub fn call_reduce_contiguous( depth: 1, }; - encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.end_encoding(); @@ -552,8 +548,7 @@ pub fn call_reduce_strided( shape: &[usize], strides: &[usize], out_length: usize, - input: &Buffer, - input_offset: usize, + input: BufferOffset, output: &Buffer, ) -> Result<(), MetalKernelError> { let length: usize = shape.iter().product(); @@ -565,14 +560,7 @@ pub fn call_reduce_strided( set_params!( encoder, - ( - shape.len(), - shape, - strides, - elements_to_sum, - (input, input_offset), - output - ) + (shape.len(), shape, strides, elements_to_sum, &input, output) ); let thread_group_count = MTLSize { @@ -593,7 +581,7 @@ pub fn call_reduce_strided( depth: 1, }; - encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.end_encoding(); @@ -1024,12 +1012,12 @@ pub fn call_where_cond_strided( kernels: &Kernels, name: &'static str, shape: &[usize], - cond: &Buffer, - (cond_stride, cond_offset): (&[usize], usize), - left: &Buffer, - (left_stride, left_offset): (&[usize], usize), - right: &Buffer, - (right_stride, right_offset): (&[usize], usize), + cond: BufferOffset, + cond_stride: &[usize], + left: BufferOffset, + left_stride: &[usize], + right: BufferOffset, + right_stride: &[usize], output: &Buffer, ) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, Source::Ternary, name)?; @@ -1049,18 +1037,18 @@ pub fn call_where_cond_strided( cond_stride, left_stride, right_stride, - (cond, cond_offset), - (left, left_offset), - (right, right_offset), + &cond, + &left, + &right, output ) ); let (thread_group_count, thread_group_size) = linear_split(&pipeline, size); - encoder.use_resource(cond, metal::MTLResourceUsage::Read); - encoder.use_resource(left, metal::MTLResourceUsage::Read); - encoder.use_resource(right, metal::MTLResourceUsage::Read); + encoder.use_resource(cond.buffer, metal::MTLResourceUsage::Read); + encoder.use_resource(left.buffer, metal::MTLResourceUsage::Read); + encoder.use_resource(right.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.end_encoding(); @@ -1079,10 +1067,8 @@ pub fn call_index_select( contiguous: bool, src_dims: &[usize], src_strides: &[usize], - input: &Buffer, - src_offset: usize, - ids: &Buffer, - ids_offset: usize, + input: BufferOffset, + ids: BufferOffset, output: &Buffer, ) -> Result<(), MetalKernelError> { let left_size: usize = shape[..dim].iter().product(); @@ -1107,16 +1093,16 @@ pub fn call_index_select( contiguous, src_dims, src_strides, - (input, src_offset), - (ids, ids_offset), + &input, + &ids, output ) ); let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el); - encoder.use_resource(input, metal::MTLResourceUsage::Read); - encoder.use_resource(ids, metal::MTLResourceUsage::Read); + encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); + encoder.use_resource(ids.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.end_encoding(); @@ -1132,10 +1118,8 @@ pub fn call_gather( shape: &[usize], ids_size: usize, dim: usize, - input: &Buffer, - input_offset: usize, - ids: &Buffer, - ids_offset: usize, + input: BufferOffset, + ids: BufferOffset, output: &Buffer, ) -> Result<(), MetalKernelError> { let left_size: usize = shape[..dim].iter().product(); @@ -1157,16 +1141,16 @@ pub fn call_gather( src_dim_size, right_size, ids_size, - (input, input_offset), - (ids, ids_offset), + &input, + &ids, output ) ); let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el); - encoder.use_resource(input, metal::MTLResourceUsage::Read); - encoder.use_resource(ids, metal::MTLResourceUsage::Read); + encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); + encoder.use_resource(ids.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.end_encoding(); @@ -1182,10 +1166,8 @@ pub fn call_scatter_add( src_shape: &[usize], dst_shape: &[usize], dim: usize, - input: &Buffer, - input_offset: usize, - ids: &Buffer, - ids_offset: usize, + input: BufferOffset, + ids: BufferOffset, output: &Buffer, ) -> Result<(), MetalKernelError> { let left_size: usize = src_shape[..dim].iter().product(); @@ -1208,16 +1190,16 @@ pub fn call_scatter_add( src_dim_size, right_size, dst_dim_size, - (input, input_offset), - (ids, ids_offset), + &input, + &ids, output ) ); let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el); - encoder.use_resource(input, metal::MTLResourceUsage::Read); - encoder.use_resource(ids, metal::MTLResourceUsage::Read); + encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); + encoder.use_resource(ids.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.end_encoding(); @@ -1234,10 +1216,8 @@ pub fn call_index_add( dst_shape: &[usize], ids_shape: &[usize], dim: usize, - input: &Buffer, - input_offset: usize, - ids: &Buffer, - ids_offset: usize, + input: BufferOffset, + ids: BufferOffset, output: &Buffer, ) -> Result<(), MetalKernelError> { let left_size: usize = src_shape[..dim].iter().product(); @@ -1261,16 +1241,16 @@ pub fn call_index_add( right_size, dst_dim_size, ids_dim_size, - (input, input_offset), - (ids, ids_offset), + &input, + &ids, output ) ); let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el); - encoder.use_resource(input, metal::MTLResourceUsage::Read); - encoder.use_resource(ids, metal::MTLResourceUsage::Read); + encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); + encoder.use_resource(ids.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.end_encoding(); @@ -1536,8 +1516,7 @@ pub fn call_im2col1d_strided( shape: &[usize], strides: &[usize], (k_size, stride, padding, dilation): (usize, usize, usize, usize), - input: &Buffer, - input_offset: usize, + input: BufferOffset, output: &Buffer, ) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, Source::Conv, name)?; @@ -1549,20 +1528,9 @@ pub fn call_im2col1d_strided( encoder.set_compute_pipeline_state(&pipeline); set_params!( encoder, - ( - dst_el, - l_out, - k_size, - stride, - padding, - dilation, - shape, - strides, - (input, input_offset), - output - ) + (dst_el, l_out, k_size, stride, padding, dilation, shape, strides, &input, output) ); - encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.end_encoding(); @@ -1579,8 +1547,7 @@ pub fn call_im2col_strided( shape: &[usize], strides: &[usize], (h_k, w_k, stride, padding, dilation): (usize, usize, usize, usize, usize), - input: &Buffer, - input_offset: usize, + input: BufferOffset, output: &Buffer, ) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, Source::Conv, name)?; @@ -1598,21 +1565,11 @@ pub fn call_im2col_strided( set_params!( encoder, ( - dst_el, - h_out, - w_out, - h_k, - w_k, - stride, - padding, - dilation, - shape, - strides, - (input, input_offset), + dst_el, h_out, w_out, h_k, w_k, stride, padding, dilation, shape, strides, &input, output ) ); - encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.end_encoding(); @@ -1630,8 +1587,7 @@ pub fn call_upsample_nearest_2d( strides: &[usize], out_w: usize, out_h: usize, - input: &Buffer, - input_offset: usize, + input: BufferOffset, output: &Buffer, ) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, Source::Conv, name)?; @@ -1643,18 +1599,9 @@ pub fn call_upsample_nearest_2d( encoder.set_compute_pipeline_state(&pipeline); set_params!( encoder, - ( - out_w, - out_h, - scale_w, - scale_h, - shape, - strides, - (input, input_offset), - output - ) + (out_w, out_h, scale_w, scale_h, shape, strides, &input, output) ); - encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.end_encoding(); diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs index b91c92d8..960ae1df 100644 --- a/candle-metal-kernels/src/tests.rs +++ b/candle-metal-kernels/src/tests.rs @@ -728,10 +728,8 @@ fn run_index_select<T: Clone, I: Clone + std::fmt::Debug>( true, shape, stride, - &embeddings_buffer, - 0, - &ids_buffer, - 0, + BufferOffset::zero_offset(&embeddings_buffer), + BufferOffset::zero_offset(&ids_buffer), &dst_buffer, ) .unwrap(); @@ -774,10 +772,8 @@ fn run_index_select_strided<T: Clone, I: Clone + std::fmt::Debug>( false, shape, stride, - &embeddings_buffer, - 0, - &ids_buffer, - 0, + BufferOffset::zero_offset(&embeddings_buffer), + BufferOffset::zero_offset(&ids_buffer), &dst_buffer, ) .unwrap(); @@ -819,8 +815,7 @@ fn run_reduce<T: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<T &dims, &strides, out_length, - &input, - 0, + BufferOffset::zero_offset(&input), &output, ) .unwrap(); @@ -974,18 +969,30 @@ fn run_where_cond<I: Clone, T: Clone>( ); let output = device.new_buffer((length * core::mem::size_of::<T>()) as u64, options); + let cond = BufferOffset { + buffer: &cond, + offset_in_bytes: cond_offset, + }; + let left = BufferOffset { + buffer: &left, + offset_in_bytes: left_offset, + }; + let right = BufferOffset { + buffer: &right, + offset_in_bytes: cond_offset, + }; call_where_cond_strided( &device, command_buffer, &kernels, name, shape, - &cond, - (&cond_stride, cond_offset), - &left, - (&left_stride, left_offset), - &right, - (&cond_stride, cond_offset), + cond, + &cond_stride, + left, + &left_stride, + right, + &cond_stride, &output, ) .unwrap(); @@ -1250,10 +1257,8 @@ fn run_scatter_add<T: Clone, I: Clone + std::fmt::Debug>( shape, shape, dim, - &input_buffer, - 0, - &ids_buffer, - 0, + BufferOffset::zero_offset(&input_buffer), + BufferOffset::zero_offset(&ids_buffer), &output, ) .unwrap(); @@ -1355,10 +1360,8 @@ fn run_index_add<T: Clone, I: Clone + std::fmt::Debug>( shape, shape, dim, - &input_buffer, - 0, - &indices_buffer, - 0, + BufferOffset::zero_offset(&input_buffer), + BufferOffset::zero_offset(&indices_buffer), &output, ) .unwrap(); |