diff options
author | Nicolas Patry <patry.nicolas@protonmail.com> | 2023-12-18 10:46:01 +0100 |
---|---|---|
committer | Nicolas Patry <patry.nicolas@protonmail.com> | 2023-12-18 10:46:01 +0100 |
commit | 8bd3d6b94bb2449d056d38d5a42c8e9c762f2d7e (patch) | |
tree | 57c07c91d8924e26c867883d39cf26eebd535b97 | |
parent | 6a3ca7da0cfb06e80d5c75ee98a1291843092e06 (diff) | |
download | candle-8bd3d6b94bb2449d056d38d5a42c8e9c762f2d7e.tar.gz candle-8bd3d6b94bb2449d056d38d5a42c8e9c762f2d7e.tar.bz2 candle-8bd3d6b94bb2449d056d38d5a42c8e9c762f2d7e.zip |
Index add.
-rw-r--r-- | candle-core/src/metal_backend.rs | 49 | ||||
-rw-r--r-- | candle-metal-kernels/src/indexing.metal | 111 | ||||
-rw-r--r-- | candle-metal-kernels/src/lib.rs | 54 |
3 files changed, 151 insertions, 63 deletions
diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs
index b26477fc..21a8967b 100644
--- a/candle-core/src/metal_backend.rs
+++ b/candle-core/src/metal_backend.rs
@@ -951,14 +951,49 @@ impl BackendStorage for MetalStorage {
     fn index_add(
         &self,
-        _: &Layout,
-        _: &Self,
-        _: &Layout,
-        _: &Self,
-        _: &Layout,
-        _: usize,
+        l: &Layout,
+        ids: &Self,
+        ids_l: &Layout,
+        src: &Self,
+        src_l: &Layout,
+        dim: usize,
     ) -> Result<Self> {
-        crate::bail!("index_add metal")
+        let mut acc = self.device.zeros_impl(l.shape(), self.dtype())?;
+        self.copy_strided_src(&mut acc, 0, l)?;
+        let (ids_offset, _) = match ids_l.contiguous_offsets() {
+            Some(o12) => o12,
+            None => Err(crate::Error::RequiresContiguous { op: "index-add" }.bt())?,
+        };
+        let src_offset = match src_l.contiguous_offsets() {
+            Some((o1, _)) => o1,
+            None => Err(crate::Error::RequiresContiguous { op: "index-add" }.bt())?,
+        };
+        let name = match (ids.dtype, self.dtype) {
+            (DType::U32, DType::F32) => "ia_u32_f32",
+            _ => Err(MetalError::UnexpectedDType {
+                msg: "index-add ids should be u8/u32/i64",
+                expected: DType::U32,
+                got: ids.dtype(),
+            })?,
+        };
+        let command_buffer = self.device.command_buffer()?;
+        candle_metal_kernels::call_index_add(
+            &self.device.device,
+            &command_buffer,
+            &self.device.kernels,
+            name,
+            src_l.dims(),
+            l.dims(),
+            ids_l.dims(),
+            dim,
+            &src.buffer,
+            src_offset * src.dtype.size_in_bytes(),
+            &ids.buffer,
+            ids_offset * ids.dtype.size_in_bytes(),
+            &acc.buffer,
+        )
+        .map_err(MetalError::from)?;
+        Ok(acc)
     }
 
     fn matmul(
         &self,
diff --git a/candle-metal-kernels/src/indexing.metal b/candle-metal-kernels/src/indexing.metal
index 72a3a348..63357428 100644
--- a/candle-metal-kernels/src/indexing.metal
+++ b/candle-metal-kernels/src/indexing.metal
@@ -122,48 +122,47 @@ kernel void NAME( \
     scatter_add<TYPENAME, INDEX_TYPENAME>(dst_size, left_size, src_dim_size, right_size, dst_dim_size, input, input_ids, output, tid); \
 }
 
-
-template <typename T, typename I>
-void index_add(
-    device I *ids [[buffer(0)]],
-    device T *inp [[buffer(1)]],
-    device T *out [[buffer(2)]],
-
-    constant uint &ids_dim_size,
-    constant uint &left_size,
-    constant uint &dst_dim_size,
-    constant uint &right_size,
-
-    uint gid [[ thread_position_in_grid ]] \
-) {
-
-    if (gid >= left_size * right_size) {
-        return;
-    }
-
-    const uint i = gid;
-    const uint pre = i / right_size;
-    const uint post = i % right_size;
-
-    for (uint j = 0; j < ids_dim_size; j++) {
-        const uint idx = ids[j];
-        const uint src_i = (pre * ids_dim_size + j) * right_size + post;
-        const uint dst_i = (pre * dst_dim_size + idx) * right_size + post;
-        out[dst_i] += inp[src_i];
+template<typename TYPENAME, typename INDEX_TYPENAME>
+METAL_FUNC void index_add(
+    constant size_t &dst_size,
+    constant size_t &left_size,
+    constant size_t &src_dim_size,
+    constant size_t &right_size,
+    constant size_t &dst_dim_size,
+    constant size_t &ids_dim_size,
+    const device TYPENAME *input,
+    const device INDEX_TYPENAME *input_ids,
+    device TYPENAME *output,
+    uint tid [[ thread_position_in_grid ]]
+) {
+    if (tid >= dst_size) {
+        return;
+    }
+    const size_t right_rank_i = tid % right_size;
+    const size_t left_rank_i = tid / right_size;
+    for (unsigned int j = 0; j < ids_dim_size; ++j) {
+        const INDEX_TYPENAME idx = input_ids[j];
+        const size_t src_i = (left_rank_i * src_dim_size + j) * right_size + right_rank_i;
+        const size_t dst_i = (left_rank_i * dst_dim_size + idx) * right_size + right_rank_i;
+        output[dst_i] += input[src_i];
     }
 }
 
-#define IA_OP(TYPENAME, INDEX_TYPENAME, FN_NAME) \
-kernel void FN_NAME( \
-    device INDEX_TYPENAME *ids [[buffer(0)]], \
-    device TYPENAME *inp [[buffer(1)]], \
-    device TYPENAME *out [[buffer(2)]], \
-    constant uint &ids_dim_size, \
-    constant uint &left_size, \
-    constant uint &dst_dim_size, \
-    constant uint &right_size, \
-    uint gid [[ thread_position_in_grid ]] \
-) { index_add<TYPENAME, INDEX_TYPENAME>(ids, inp, out, ids_dim_size, left_size, dst_dim_size, right_size, gid); } \
+# define INDEX_ADD_OP(NAME, INDEX_TYPENAME, TYPENAME) \
+kernel void NAME( \
+    constant size_t &dst_size, \
+    constant size_t &left_size, \
+    constant size_t &src_dim_size, \
+    constant size_t &right_size, \
+    constant size_t &dst_dim_size, \
+    constant size_t &ids_dim_size, \
+    const device TYPENAME *input, \
+    const device INDEX_TYPENAME *input_ids, \
+    device TYPENAME *output, \
+    uint tid [[ thread_position_in_grid ]] \
+) { \
+    index_add<TYPENAME, INDEX_TYPENAME>(dst_size, left_size, src_dim_size, right_size, dst_dim_size, ids_dim_size, input, input_ids, output, tid); \
+}
 
 
 INDEX_OP(is_u32_f32, uint, float)
@@ -175,25 +174,25 @@ SCATTER_ADD_OP(sa_u32_f16, uint, half)
 
 
 #if __METAL_VERSION__ >= 310
-IA_OP(bfloat, int64_t, ia_i64_bf16)
-IA_OP(bfloat, uint32_t, ia_u32_bf16)
-IA_OP(bfloat, uint8_t, ia_u8_bf16)
+INDEX_ADD_OP(ia_i64_bf16, int64_t, bfloat)
+INDEX_ADD_OP(ia_u32_bf16, uint32_t, bfloat)
+INDEX_ADD_OP(ia_u8_bf16, uint8_t, bfloat)
 #endif
 
-IA_OP(half, uint32_t, ia_u32_f16)
-IA_OP(half, uint8_t, ia_u8_f16)
+INDEX_ADD_OP(ia_u32_f16, uint32_t, half)
+INDEX_ADD_OP(ia_u8_f16, uint8_t, half)
 
-IA_OP(float, int64_t, ia_i64_f32)
-IA_OP(uint8_t, int64_t, ia_i64_u8)
-IA_OP(int64_t, int64_t, ia_i64_i64)
-IA_OP(uint32_t, int64_t, ia_i64_u32)
+INDEX_ADD_OP(ia_i64_f32, int64_t, float)
+INDEX_ADD_OP(ia_i64_u8, int64_t, uint8_t)
+INDEX_ADD_OP(ia_i64_i64, int64_t, int64_t)
+INDEX_ADD_OP(ia_i64_u32, int64_t, uint32_t)
 
-IA_OP(float, uint32_t, ia_u32_f32)
-IA_OP(uint8_t, uint32_t, ia_u32_u8)
-IA_OP(int64_t, uint32_t, ia_u32_i64)
-IA_OP(uint32_t, uint32_t, ia_u32_u32)
+INDEX_ADD_OP(ia_u32_f32, uint32_t, float)
+INDEX_ADD_OP(ia_u32_u8, uint32_t, uint8_t)
+INDEX_ADD_OP(ia_u32_i64, uint32_t, int64_t)
+INDEX_ADD_OP(ia_u32_u32, uint32_t, uint32_t)
 
-IA_OP(float, uint8_t, ia_u8_f32)
-IA_OP(uint8_t, uint8_t, ia_u8_u8)
-IA_OP(uint32_t, uint8_t, ia_u8_u32)
-IA_OP(int64_t, uint8_t, ia_u8_i64)
+INDEX_ADD_OP(ia_u8_f32, uint8_t, float)
+INDEX_ADD_OP(ia_u8_u8, uint8_t, uint8_t)
+INDEX_ADD_OP(ia_u8_u32, uint8_t, uint32_t)
+INDEX_ADD_OP(ia_u8_i64, uint8_t, int64_t)
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index ddc04d05..0bd7d8cb 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -1114,6 +1114,60 @@ pub fn call_scatter_add(
     Ok(())
 }
 
+pub fn call_index_add(
+    device: &Device,
+    command_buffer: &CommandBufferRef,
+    kernels: &Kernels,
+    name: &'static str,
+    src_shape: &[usize],
+    dst_shape: &[usize],
+    ids_shape: &[usize],
+    dim: usize,
+    input: &Buffer,
+    input_offset: usize,
+    ids: &Buffer,
+    ids_offset: usize,
+    output: &Buffer,
+) -> Result<(), MetalKernelError> {
+    let left_size: usize = src_shape[..dim].iter().product();
+    let right_size: usize = src_shape[dim + 1..].iter().product();
+    let src_dim_size = src_shape[dim];
+    let dst_el = left_size * right_size;
+    let dst_dim_size = dst_shape[dim];
+    let ids_dim_size = ids_shape[0];
+
+    let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?;
+    let encoder = command_buffer.new_compute_command_encoder();
+
+    encoder.wait_for_fence(&kernels.fence);
+    encoder.set_compute_pipeline_state(&pipeline);
+
+    set_params!(
+        encoder,
+        (
+            dst_el,
+            left_size,
+            src_dim_size,
+            right_size,
+            dst_dim_size,
+            ids_dim_size,
+            (input, input_offset),
+            (ids, ids_offset),
+            output
+        )
+    );
+
+    let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
+
+    encoder.use_resource(input, metal::MTLResourceUsage::Read);
+    encoder.use_resource(ids, metal::MTLResourceUsage::Read);
+    encoder.use_resource(output, metal::MTLResourceUsage::Write);
+    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+    encoder.update_fence(&kernels.fence);
+    encoder.end_encoding();
+    Ok(())
+}
+
 #[derive(Debug, PartialEq)]
 pub enum Value {
     USize(usize), |