diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-04-08 09:37:25 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-04-08 09:37:25 +0200 |
commit | 718671a0d5b751458033fb6425fb518ca4dc3b5f (patch) | |
tree | e72ccab88dbb1ee28878664ba2574bdb5563eca1 | |
parent | c5fe4a7f8983ae7c9641fa923f26ef60538aef06 (diff) | |
download | candle-718671a0d5b751458033fb6425fb518ca4dc3b5f.tar.gz candle-718671a0d5b751458033fb6425fb518ca4dc3b5f.tar.bz2 candle-718671a0d5b751458033fb6425fb518ca4dc3b5f.zip |
Use BufferOffset in metal backend ops. (#2029)
* Use BufferOffset in the metal backend.
* More BufferOffset usage.
* Use in where-cond.
-rw-r--r-- | candle-core/src/metal_backend/mod.rs | 89 |
-rw-r--r-- | candle-metal-kernels/src/lib.rs | 155 |
-rw-r--r-- | candle-metal-kernels/src/tests.rs | 51 |
3 files changed, 117 insertions, 178 deletions
diff --git a/candle-core/src/metal_backend/mod.rs b/candle-core/src/metal_backend/mod.rs index 4adcda05..50149a9d 100644 --- a/candle-core/src/metal_backend/mod.rs +++ b/candle-core/src/metal_backend/mod.rs @@ -314,6 +314,7 @@ impl BackendStorage for MetalStorage { let dtype = if return_index { DType::U32 } else { self.dtype }; let buffer = device.new_buffer(dst_el, dtype, "reduce")?; let command_buffer = self.device.command_buffer()?; + let src = buffer_o(&self.buffer, layout, self.dtype); candle_metal_kernels::call_reduce_strided( &device.device, &command_buffer, @@ -322,8 +323,7 @@ impl BackendStorage for MetalStorage { &dims, &stride, dst_el, - &self.buffer, - layout.start_offset() * self.dtype.size_in_bytes(), + src, &buffer, ) .map_err(MetalError::from)?; @@ -617,21 +617,21 @@ impl BackendStorage for MetalStorage { (DType::U8, DType::U8) => "where_u8_u8", (left, right) => crate::bail!("Metal where_cond {left:?} {right:?} not implemented"), }; + let src = buffer_o(&self.buffer, layout, self.dtype); + let t = buffer_o(&t.buffer, t_l, t.dtype); + let f = buffer_o(&f.buffer, f_l, f.dtype); candle_metal_kernels::call_where_cond_strided( &device.device, &command_buffer, &device.kernels, name, dims, - &self.buffer, - ( - layout.stride(), - layout.start_offset() * self.dtype.size_in_bytes(), - ), - &t.buffer, - (t_l.stride(), t_l.start_offset() * t.dtype.size_in_bytes()), - &f.buffer, - (f_l.stride(), f_l.start_offset() * f.dtype.size_in_bytes()), + src, + layout.stride(), + t, + t_l.stride(), + f, + f_l.stride(), &buffer, ) .map_err(MetalError::from)?; @@ -664,6 +664,7 @@ impl BackendStorage for MetalStorage { DType::F32 => "im2col1d_f32", dtype => crate::bail!("Metal conv1d {dtype:?} not implemented"), }; + let src = buffer_o(&self.buffer, layout, self.dtype); candle_metal_kernels::call_im2col1d_strided( &self.device.device, &command_buffer, @@ -672,8 +673,7 @@ impl BackendStorage for MetalStorage { layout.shape().dims(), strides, (k_size, stride, padding, 
dilation), - &self.buffer, - layout.start_offset() * self.dtype.size_in_bytes(), + src, &dst, ) .map_err(MetalError::from)?; @@ -791,6 +791,7 @@ impl BackendStorage for MetalStorage { DType::U32 => "im2col_u32", dtype => crate::bail!("Metal conv2d {dtype:?} not implemented"), }; + let src = buffer_o(&self.buffer, layout, self.dtype); candle_metal_kernels::call_im2col_strided( &self.device.device, &command_buffer, @@ -799,8 +800,7 @@ impl BackendStorage for MetalStorage { layout.shape().dims(), layout.stride(), (h_k, w_k, stride, padding, dilation), - &self.buffer, - layout.start_offset() * self.dtype.size_in_bytes(), + src, &dst, ) .map_err(MetalError::from)?; @@ -1013,6 +1013,7 @@ impl BackendStorage for MetalStorage { .device .new_buffer(dst_el, self.dtype, "upsample_nearest2d")?; let command_buffer = self.device.command_buffer()?; + let src = buffer_o(&self.buffer, inp_l, self.dtype); candle_metal_kernels::call_upsample_nearest_2d( &self.device.device, &command_buffer, @@ -1022,8 +1023,7 @@ impl BackendStorage for MetalStorage { strides, out_w, out_h, - &self.buffer, - inp_l.start_offset() * self.dtype.size_in_bytes(), + src, &buffer, ) .map_err(MetalError::from)?; @@ -1031,9 +1031,8 @@ impl BackendStorage for MetalStorage { } fn gather(&self, src_l: &Layout, ids: &Self, ids_l: &Layout, dim: usize) -> Result<Self> { - let (ids_o1, _) = match ids_l.contiguous_offsets() { - Some(o12) => o12, - None => Err(crate::Error::RequiresContiguous { op: "gather" }.bt())?, + if !ids_l.is_contiguous() { + return Err(crate::Error::RequiresContiguous { op: "gather" }.bt()); }; let ids_el = ids_l.dims()[dim]; let dst_el = ids_l.shape().elem_count(); @@ -1046,6 +1045,8 @@ impl BackendStorage for MetalStorage { (left, right) => crate::bail!("Metal gather {left:?} {right:?} not implemented"), }; let command_buffer = self.device.command_buffer()?; + let src = buffer_o(&self.buffer, src_l, dtype); + let ids = buffer_o(&ids.buffer, ids_l, ids.dtype); candle_metal_kernels::call_gather( 
&device.device, &command_buffer, @@ -1054,10 +1055,8 @@ impl BackendStorage for MetalStorage { src_l.dims(), ids_el, dim, - &self.buffer, - src_l.start_offset() * dtype.size_in_bytes(), - &ids.buffer, - ids_o1 * ids.dtype.size_in_bytes(), + src, + ids, &buffer, ) .map_err(MetalError::from)?; @@ -1075,13 +1074,8 @@ impl BackendStorage for MetalStorage { ) -> Result<Self> { let mut acc = self.device.zeros_impl(l.shape(), self.dtype())?; self.copy_strided_src(&mut acc, 0, l)?; - let (ids_offset, _) = match ids_l.contiguous_offsets() { - Some(o12) => o12, - None => Err(crate::Error::RequiresContiguous { op: "scatter-add" }.bt())?, - }; - let src_offset = match src_l.contiguous_offsets() { - Some((o1, _)) => o1, - None => Err(crate::Error::RequiresContiguous { op: "scatter-add" }.bt())?, + if !ids_l.is_contiguous() || !src_l.is_contiguous() { + return Err(crate::Error::RequiresContiguous { op: "scatter-add" }.bt()); }; let name = match (ids.dtype, self.dtype) { (DType::U8, DType::F32) => "sa_u8_f32", @@ -1100,6 +1094,8 @@ impl BackendStorage for MetalStorage { })?, }; let command_buffer = self.device.command_buffer()?; + let src = buffer_o(&src.buffer, src_l, src.dtype); + let ids = buffer_o(&ids.buffer, ids_l, ids.dtype); candle_metal_kernels::call_scatter_add( &self.device.device, &command_buffer, @@ -1108,10 +1104,8 @@ impl BackendStorage for MetalStorage { src_l.dims(), l.dims(), dim, - &src.buffer, - src_offset * src.dtype.size_in_bytes(), - &ids.buffer, - ids_offset * ids.dtype.size_in_bytes(), + src, + ids, &acc.buffer, ) .map_err(MetalError::from)?; @@ -1147,6 +1141,8 @@ impl BackendStorage for MetalStorage { } }; let command_buffer = self.device.command_buffer()?; + let src = buffer_o(&self.buffer, src_l, dtype); + let ids = buffer_o(&ids.buffer, ids_l, ids.dtype); candle_metal_kernels::call_index_select( &device.device, &command_buffer, @@ -1158,10 +1154,8 @@ impl BackendStorage for MetalStorage { src_l.is_contiguous(), src_l.dims(), src_l.stride(), - 
&self.buffer, - src_l.start_offset() * dtype.size_in_bytes(), - &ids.buffer, - ids_l.start_offset() * ids.dtype.size_in_bytes(), + src, + ids, &buffer, ) .map_err(MetalError::from)?; @@ -1179,13 +1173,8 @@ impl BackendStorage for MetalStorage { ) -> Result<Self> { let mut acc = self.device.zeros_impl(l.shape(), self.dtype())?; self.copy_strided_src(&mut acc, 0, l)?; - let (ids_offset, _) = match ids_l.contiguous_offsets() { - Some(o12) => o12, - None => Err(crate::Error::RequiresContiguous { op: "index-add" }.bt())?, - }; - let src_offset = match src_l.contiguous_offsets() { - Some((o1, _)) => o1, - None => Err(crate::Error::RequiresContiguous { op: "index-add" }.bt())?, + if !ids_l.is_contiguous() || !src_l.is_contiguous() { + return Err(crate::Error::RequiresContiguous { op: "index-add" }.bt()); }; let name = match (ids.dtype, self.dtype) { (DType::I64, DType::BF16) => "ia_i64_bf16", @@ -1216,6 +1205,8 @@ impl BackendStorage for MetalStorage { })?, }; let command_buffer = self.device.command_buffer()?; + let src = buffer_o(&src.buffer, src_l, src.dtype); + let ids = buffer_o(&ids.buffer, ids_l, ids.dtype); candle_metal_kernels::call_index_add( &self.device.device, &command_buffer, @@ -1225,10 +1216,8 @@ impl BackendStorage for MetalStorage { l.dims(), ids_l.dims(), dim, - &src.buffer, - src_offset * src.dtype.size_in_bytes(), - &ids.buffer, - ids_offset * ids.dtype.size_in_bytes(), + src, + ids, &acc.buffer, ) .map_err(MetalError::from)?; diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index 23c072af..78108127 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -503,8 +503,7 @@ pub fn call_reduce_contiguous( kernel_name: &'static str, length: usize, out_length: usize, - input: &Buffer, - input_offset: usize, + input: BufferOffset, output: &Buffer, ) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?; @@ -513,10 +512,7 @@ pub fn 
call_reduce_contiguous( let encoder = command_buffer.new_compute_command_encoder(); encoder.set_compute_pipeline_state(&pipeline); - set_params!( - encoder, - (length, elements_to_sum, (input, input_offset), output) - ); + set_params!(encoder, (length, elements_to_sum, &input, output)); let thread_group_count = MTLSize { width: out_length as u64, @@ -536,7 +532,7 @@ pub fn call_reduce_contiguous( depth: 1, }; - encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.end_encoding(); @@ -552,8 +548,7 @@ pub fn call_reduce_strided( shape: &[usize], strides: &[usize], out_length: usize, - input: &Buffer, - input_offset: usize, + input: BufferOffset, output: &Buffer, ) -> Result<(), MetalKernelError> { let length: usize = shape.iter().product(); @@ -565,14 +560,7 @@ pub fn call_reduce_strided( set_params!( encoder, - ( - shape.len(), - shape, - strides, - elements_to_sum, - (input, input_offset), - output - ) + (shape.len(), shape, strides, elements_to_sum, &input, output) ); let thread_group_count = MTLSize { @@ -593,7 +581,7 @@ pub fn call_reduce_strided( depth: 1, }; - encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.end_encoding(); @@ -1024,12 +1012,12 @@ pub fn call_where_cond_strided( kernels: &Kernels, name: &'static str, shape: &[usize], - cond: &Buffer, - (cond_stride, cond_offset): (&[usize], usize), - left: &Buffer, - (left_stride, left_offset): (&[usize], usize), - right: &Buffer, - (right_stride, right_offset): (&[usize], usize), + cond: BufferOffset, + cond_stride: &[usize], + left: BufferOffset, + left_stride: &[usize], + right: 
BufferOffset, + right_stride: &[usize], output: &Buffer, ) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, Source::Ternary, name)?; @@ -1049,18 +1037,18 @@ pub fn call_where_cond_strided( cond_stride, left_stride, right_stride, - (cond, cond_offset), - (left, left_offset), - (right, right_offset), + &cond, + &left, + &right, output ) ); let (thread_group_count, thread_group_size) = linear_split(&pipeline, size); - encoder.use_resource(cond, metal::MTLResourceUsage::Read); - encoder.use_resource(left, metal::MTLResourceUsage::Read); - encoder.use_resource(right, metal::MTLResourceUsage::Read); + encoder.use_resource(cond.buffer, metal::MTLResourceUsage::Read); + encoder.use_resource(left.buffer, metal::MTLResourceUsage::Read); + encoder.use_resource(right.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.end_encoding(); @@ -1079,10 +1067,8 @@ pub fn call_index_select( contiguous: bool, src_dims: &[usize], src_strides: &[usize], - input: &Buffer, - src_offset: usize, - ids: &Buffer, - ids_offset: usize, + input: BufferOffset, + ids: BufferOffset, output: &Buffer, ) -> Result<(), MetalKernelError> { let left_size: usize = shape[..dim].iter().product(); @@ -1107,16 +1093,16 @@ pub fn call_index_select( contiguous, src_dims, src_strides, - (input, src_offset), - (ids, ids_offset), + &input, + &ids, output ) ); let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el); - encoder.use_resource(input, metal::MTLResourceUsage::Read); - encoder.use_resource(ids, metal::MTLResourceUsage::Read); + encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); + encoder.use_resource(ids.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.end_encoding(); @@ -1132,10 
+1118,8 @@ pub fn call_gather( shape: &[usize], ids_size: usize, dim: usize, - input: &Buffer, - input_offset: usize, - ids: &Buffer, - ids_offset: usize, + input: BufferOffset, + ids: BufferOffset, output: &Buffer, ) -> Result<(), MetalKernelError> { let left_size: usize = shape[..dim].iter().product(); @@ -1157,16 +1141,16 @@ pub fn call_gather( src_dim_size, right_size, ids_size, - (input, input_offset), - (ids, ids_offset), + &input, + &ids, output ) ); let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el); - encoder.use_resource(input, metal::MTLResourceUsage::Read); - encoder.use_resource(ids, metal::MTLResourceUsage::Read); + encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); + encoder.use_resource(ids.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.end_encoding(); @@ -1182,10 +1166,8 @@ pub fn call_scatter_add( src_shape: &[usize], dst_shape: &[usize], dim: usize, - input: &Buffer, - input_offset: usize, - ids: &Buffer, - ids_offset: usize, + input: BufferOffset, + ids: BufferOffset, output: &Buffer, ) -> Result<(), MetalKernelError> { let left_size: usize = src_shape[..dim].iter().product(); @@ -1208,16 +1190,16 @@ pub fn call_scatter_add( src_dim_size, right_size, dst_dim_size, - (input, input_offset), - (ids, ids_offset), + &input, + &ids, output ) ); let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el); - encoder.use_resource(input, metal::MTLResourceUsage::Read); - encoder.use_resource(ids, metal::MTLResourceUsage::Read); + encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); + encoder.use_resource(ids.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.end_encoding(); @@ -1234,10 +1216,8 @@ pub fn call_index_add( 
dst_shape: &[usize], ids_shape: &[usize], dim: usize, - input: &Buffer, - input_offset: usize, - ids: &Buffer, - ids_offset: usize, + input: BufferOffset, + ids: BufferOffset, output: &Buffer, ) -> Result<(), MetalKernelError> { let left_size: usize = src_shape[..dim].iter().product(); @@ -1261,16 +1241,16 @@ pub fn call_index_add( right_size, dst_dim_size, ids_dim_size, - (input, input_offset), - (ids, ids_offset), + &input, + &ids, output ) ); let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el); - encoder.use_resource(input, metal::MTLResourceUsage::Read); - encoder.use_resource(ids, metal::MTLResourceUsage::Read); + encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); + encoder.use_resource(ids.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.end_encoding(); @@ -1536,8 +1516,7 @@ pub fn call_im2col1d_strided( shape: &[usize], strides: &[usize], (k_size, stride, padding, dilation): (usize, usize, usize, usize), - input: &Buffer, - input_offset: usize, + input: BufferOffset, output: &Buffer, ) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, Source::Conv, name)?; @@ -1549,20 +1528,9 @@ pub fn call_im2col1d_strided( encoder.set_compute_pipeline_state(&pipeline); set_params!( encoder, - ( - dst_el, - l_out, - k_size, - stride, - padding, - dilation, - shape, - strides, - (input, input_offset), - output - ) + (dst_el, l_out, k_size, stride, padding, dilation, shape, strides, &input, output) ); - encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.end_encoding(); @@ -1579,8 +1547,7 @@ pub fn call_im2col_strided( shape: &[usize], strides: &[usize], (h_k, 
w_k, stride, padding, dilation): (usize, usize, usize, usize, usize), - input: &Buffer, - input_offset: usize, + input: BufferOffset, output: &Buffer, ) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, Source::Conv, name)?; @@ -1598,21 +1565,11 @@ pub fn call_im2col_strided( set_params!( encoder, ( - dst_el, - h_out, - w_out, - h_k, - w_k, - stride, - padding, - dilation, - shape, - strides, - (input, input_offset), + dst_el, h_out, w_out, h_k, w_k, stride, padding, dilation, shape, strides, &input, output ) ); - encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.end_encoding(); @@ -1630,8 +1587,7 @@ pub fn call_upsample_nearest_2d( strides: &[usize], out_w: usize, out_h: usize, - input: &Buffer, - input_offset: usize, + input: BufferOffset, output: &Buffer, ) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, Source::Conv, name)?; @@ -1643,18 +1599,9 @@ pub fn call_upsample_nearest_2d( encoder.set_compute_pipeline_state(&pipeline); set_params!( encoder, - ( - out_w, - out_h, - scale_w, - scale_h, - shape, - strides, - (input, input_offset), - output - ) + (out_w, out_h, scale_w, scale_h, shape, strides, &input, output) ); - encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.end_encoding(); diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs index b91c92d8..960ae1df 100644 --- a/candle-metal-kernels/src/tests.rs +++ b/candle-metal-kernels/src/tests.rs @@ -728,10 +728,8 @@ fn run_index_select<T: Clone, I: Clone + std::fmt::Debug>( true, shape, 
stride, - &embeddings_buffer, - 0, - &ids_buffer, - 0, + BufferOffset::zero_offset(&embeddings_buffer), + BufferOffset::zero_offset(&ids_buffer), &dst_buffer, ) .unwrap(); @@ -774,10 +772,8 @@ fn run_index_select_strided<T: Clone, I: Clone + std::fmt::Debug>( false, shape, stride, - &embeddings_buffer, - 0, - &ids_buffer, - 0, + BufferOffset::zero_offset(&embeddings_buffer), + BufferOffset::zero_offset(&ids_buffer), &dst_buffer, ) .unwrap(); @@ -819,8 +815,7 @@ fn run_reduce<T: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<T &dims, &strides, out_length, - &input, - 0, + BufferOffset::zero_offset(&input), &output, ) .unwrap(); @@ -974,18 +969,30 @@ fn run_where_cond<I: Clone, T: Clone>( ); let output = device.new_buffer((length * core::mem::size_of::<T>()) as u64, options); + let cond = BufferOffset { + buffer: &cond, + offset_in_bytes: cond_offset, + }; + let left = BufferOffset { + buffer: &left, + offset_in_bytes: left_offset, + }; + let right = BufferOffset { + buffer: &right, + offset_in_bytes: cond_offset, + }; call_where_cond_strided( &device, command_buffer, &kernels, name, shape, - &cond, - (&cond_stride, cond_offset), - &left, - (&left_stride, left_offset), - &right, - (&cond_stride, cond_offset), + cond, + &cond_stride, + left, + &left_stride, + right, + &cond_stride, &output, ) .unwrap(); @@ -1250,10 +1257,8 @@ fn run_scatter_add<T: Clone, I: Clone + std::fmt::Debug>( shape, shape, dim, - &input_buffer, - 0, - &ids_buffer, - 0, + BufferOffset::zero_offset(&input_buffer), + BufferOffset::zero_offset(&ids_buffer), &output, ) .unwrap(); @@ -1355,10 +1360,8 @@ fn run_index_add<T: Clone, I: Clone + std::fmt::Debug>( shape, shape, dim, - &input_buffer, - 0, - &indices_buffer, - 0, + BufferOffset::zero_offset(&input_buffer), + BufferOffset::zero_offset(&indices_buffer), &output, ) .unwrap(); |