diff options
author | Thomas Santerre <thomas@santerre.xyz> | 2024-03-22 02:30:02 -0400 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-03-22 07:30:02 +0100 |
commit | fee33b45c2b635d83fa2ca0955ae453fe26374ea (patch) | |
tree | c4a54dc5d23704b2493d6b146a60729215d305ec /candle-metal-kernels/src/indexing.metal | |
parent | 6708870e633af636660c556c19703c38cbe2af8d (diff) | |
download | candle-fee33b45c2b635d83fa2ca0955ae453fe26374ea.tar.gz candle-fee33b45c2b635d83fa2ca0955ae453fe26374ea.tar.bz2 candle-fee33b45c2b635d83fa2ca0955ae453fe26374ea.zip |
Add support for strided index-select on Metal (#1909)
* initial implementation
* use correct index, but still not breaking like it should have...
* fix test
Diffstat (limited to 'candle-metal-kernels/src/indexing.metal')
-rw-r--r-- | candle-metal-kernels/src/indexing.metal | 41 |
1 files changed, 34 insertions, 7 deletions
```diff
diff --git a/candle-metal-kernels/src/indexing.metal b/candle-metal-kernels/src/indexing.metal
index 65491759..ad4a8605 100644
--- a/candle-metal-kernels/src/indexing.metal
+++ b/candle-metal-kernels/src/indexing.metal
@@ -1,20 +1,38 @@
 #include <metal_stdlib>
 using namespace metal;
 
+METAL_FUNC uint get_strided_index(
+    uint idx,
+    constant size_t &num_dims,
+    constant size_t *dims,
+    constant size_t *strides
+) {
+    uint strided_i = 0;
+    for (uint d = 0; d < num_dims; d++) {
+        uint dim_idx = num_dims - 1 - d;
+        strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
+        idx /= dims[dim_idx];
+    }
+    return strided_i;
+}
+
 template<typename TYPENAME, typename INDEX_TYPENAME>
 METAL_FUNC void index(
     constant size_t &dst_size,
     constant size_t &left_size,
     constant size_t &src_dim_size,
     constant size_t &right_size,
-    constant size_t &ids_size,
-    const device TYPENAME *input,
+    constant size_t &ids_size,
+    constant bool &contiguous,
+    constant size_t *src_dims,
+    constant size_t *src_strides,
+    const device TYPENAME *input,
     const device INDEX_TYPENAME *input_ids,
     device TYPENAME *output,
     uint tid [[ thread_position_in_grid ]]
 ) {
     if (tid >= dst_size) {
-        return;
+        return;
     }
     const size_t id_i = (tid / right_size) % ids_size;
     const INDEX_TYPENAME input_i = min(input_ids[id_i], (INDEX_TYPENAME)(src_dim_size - 1));
@@ -26,7 +44,8 @@ METAL_FUNC void index(
     // No need to check for zero we're only allowing unsized.
     */
     const size_t src_i = left_rank_i * src_dim_size * right_size + input_i * right_size + right_rank_i;
-    output[tid] = input[src_i];
+    const size_t strided_src_i = contiguous ? src_i : get_strided_index(src_i, src_dim_size, src_dims, src_strides);
+    output[tid] = input[strided_src_i];
 }
 
 # define INDEX_OP(NAME, INDEX_TYPENAME, TYPENAME) \
@@ -36,12 +55,15 @@ kernel void NAME( \
     constant size_t &src_dim_size, \
     constant size_t &right_size, \
     constant size_t &ids_size, \
+    constant bool &contiguous, \
+    constant size_t *src_dims, \
+    constant size_t *src_strides, \
     const device TYPENAME *input, \
     const device INDEX_TYPENAME *input_ids, \
     device TYPENAME *output, \
     uint tid [[ thread_position_in_grid ]] \
 ) { \
-    index<TYPENAME, INDEX_TYPENAME>(dst_size, left_size, src_dim_size, right_size, ids_size, input, input_ids, output, tid); \
+    index<TYPENAME, INDEX_TYPENAME>(dst_size, left_size, src_dim_size, right_size, ids_size, contiguous, src_dims, src_strides, input, input_ids, output, tid); \
 }
 
 
@@ -165,10 +187,15 @@ kernel void NAME( \
 }
 
 
-INDEX_OP(is_u32_f32, uint, float)
-INDEX_OP(is_u32_f16, uint, half)
+INDEX_OP(is_u32_f32, uint32_t, float)
+INDEX_OP(is_u32_f16, uint32_t, half)
 #if defined(__HAVE_BFLOAT__)
 INDEX_OP(is_u32_bf16, uint32_t, bfloat)
+#endif
+
+INDEX_OP(is_u8_f32, uint8_t, float)
+INDEX_OP(is_u8_f16, uint8_t, half)
+#if defined(__HAVE_BFLOAT__)
 INDEX_OP(is_u8_bf16, uint8_t, bfloat)
 #endif
 
```