summaryrefslogtreecommitdiff
path: root/candle-metal-kernels/src/indexing.metal
diff options
context:
space:
mode:
authorThomas Santerre <thomas@santerre.xyz>2024-03-22 02:30:02 -0400
committerGitHub <noreply@github.com>2024-03-22 07:30:02 +0100
commitfee33b45c2b635d83fa2ca0955ae453fe26374ea (patch)
treec4a54dc5d23704b2493d6b146a60729215d305ec /candle-metal-kernels/src/indexing.metal
parent6708870e633af636660c556c19703c38cbe2af8d (diff)
downloadcandle-fee33b45c2b635d83fa2ca0955ae453fe26374ea.tar.gz
candle-fee33b45c2b635d83fa2ca0955ae453fe26374ea.tar.bz2
candle-fee33b45c2b635d83fa2ca0955ae453fe26374ea.zip
Add support for strided index-select on Metal (#1909)
* initial implementation * use correct index, but still not breaking like it should have... * fix test
Diffstat (limited to 'candle-metal-kernels/src/indexing.metal')
-rw-r--r--candle-metal-kernels/src/indexing.metal41
1 files changed, 34 insertions, 7 deletions
diff --git a/candle-metal-kernels/src/indexing.metal b/candle-metal-kernels/src/indexing.metal
index 65491759..ad4a8605 100644
--- a/candle-metal-kernels/src/indexing.metal
+++ b/candle-metal-kernels/src/indexing.metal
@@ -1,20 +1,38 @@
#include <metal_stdlib>
using namespace metal;
+METAL_FUNC uint get_strided_index(
+ uint idx,
+ constant size_t &num_dims,
+ constant size_t *dims,
+ constant size_t *strides
+) {
+ uint strided_i = 0;
+ for (uint d = 0; d < num_dims; d++) {
+ uint dim_idx = num_dims - 1 - d;
+ strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
+ idx /= dims[dim_idx];
+ }
+ return strided_i;
+}
+
template<typename TYPENAME, typename INDEX_TYPENAME>
METAL_FUNC void index(
constant size_t &dst_size,
constant size_t &left_size,
constant size_t &src_dim_size,
constant size_t &right_size,
- constant size_t &ids_size,
- const device TYPENAME *input,
+ constant size_t &ids_size,
+ constant bool &contiguous,
+ constant size_t *src_dims,
+ constant size_t *src_strides,
+ const device TYPENAME *input,
const device INDEX_TYPENAME *input_ids,
device TYPENAME *output,
uint tid [[ thread_position_in_grid ]]
) {
if (tid >= dst_size) {
- return;
+ return;
}
const size_t id_i = (tid / right_size) % ids_size;
const INDEX_TYPENAME input_i = min(input_ids[id_i], (INDEX_TYPENAME)(src_dim_size - 1));
@@ -26,7 +44,8 @@ METAL_FUNC void index(
// No need to check for zero we're only allowing unsized.
*/
const size_t src_i = left_rank_i * src_dim_size * right_size + input_i * right_size + right_rank_i;
- output[tid] = input[src_i];
+ const size_t strided_src_i = contiguous ? src_i : get_strided_index(src_i, src_dim_size, src_dims, src_strides);
+ output[tid] = input[strided_src_i];
}
# define INDEX_OP(NAME, INDEX_TYPENAME, TYPENAME) \
@@ -36,12 +55,15 @@ kernel void NAME( \
constant size_t &src_dim_size, \
constant size_t &right_size, \
constant size_t &ids_size, \
+ constant bool &contiguous, \
+ constant size_t *src_dims, \
+ constant size_t *src_strides, \
const device TYPENAME *input, \
const device INDEX_TYPENAME *input_ids, \
device TYPENAME *output, \
uint tid [[ thread_position_in_grid ]] \
) { \
- index<TYPENAME, INDEX_TYPENAME>(dst_size, left_size, src_dim_size, right_size, ids_size, input, input_ids, output, tid); \
+ index<TYPENAME, INDEX_TYPENAME>(dst_size, left_size, src_dim_size, right_size, ids_size, contiguous, src_dims, src_strides, input, input_ids, output, tid); \
}
@@ -165,10 +187,15 @@ kernel void NAME( \
}
-INDEX_OP(is_u32_f32, uint, float)
-INDEX_OP(is_u32_f16, uint, half)
+INDEX_OP(is_u32_f32, uint32_t, float)
+INDEX_OP(is_u32_f16, uint32_t, half)
#if defined(__HAVE_BFLOAT__)
INDEX_OP(is_u32_bf16, uint32_t, bfloat)
+#endif
+
+INDEX_OP(is_u8_f32, uint8_t, float)
+INDEX_OP(is_u8_f16, uint8_t, half)
+#if defined(__HAVE_BFLOAT__)
INDEX_OP(is_u8_bf16, uint8_t, bfloat)
#endif