summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorzachcp <zachcp@users.noreply.github.com>2024-11-30 17:18:07 -0500
committerGitHub <noreply@github.com>2024-11-30 23:18:07 +0100
commitdba7a9c93e4c84c8197e8a5b56f40adcf2650bde (patch)
treec498e4854c3333c7f5ad083d0874601504835c70
parentb52c2c60508325431df5e05eca9801060fdbcc1c (diff)
downloadcandle-dba7a9c93e4c84c8197e8a5b56f40adcf2650bde.tar.gz
candle-dba7a9c93e4c84c8197e8a5b56f40adcf2650bde.tar.bz2
candle-dba7a9c93e4c84c8197e8a5b56f40adcf2650bde.zip
add u32 - U32 gather (#2653)
-rw-r--r--candle-core/src/metal_backend/mod.rs1
-rw-r--r--candle-metal-kernels/src/indexing.metal159
2 files changed, 81 insertions, 79 deletions
diff --git a/candle-core/src/metal_backend/mod.rs b/candle-core/src/metal_backend/mod.rs
index 47f54c8d..e8159f46 100644
--- a/candle-core/src/metal_backend/mod.rs
+++ b/candle-core/src/metal_backend/mod.rs
@@ -1244,6 +1244,7 @@ impl BackendStorage for MetalStorage {
(DType::U32, DType::F32) => "gather_u32_f32",
(DType::U32, DType::F16) => "gather_u32_f16",
(DType::U32, DType::BF16) => "gather_u32_bf16",
+ (DType::U32, DType::U32) => "gather_u32_u32",
(left, right) => crate::bail!("Metal gather {left:?} {right:?} not implemented"),
};
let command_buffer = self.device.command_buffer()?;
diff --git a/candle-metal-kernels/src/indexing.metal b/candle-metal-kernels/src/indexing.metal
index c14f2c1f..2594689c 100644
--- a/candle-metal-kernels/src/indexing.metal
+++ b/candle-metal-kernels/src/indexing.metal
@@ -17,33 +17,33 @@ METAL_FUNC uint get_strided_index(
}
template<typename TYPENAME, typename INDEX_TYPENAME>
-METAL_FUNC void index(
- constant size_t &dst_size,
- constant size_t &left_size,
- constant size_t &src_dim_size,
- constant size_t &right_size,
+METAL_FUNC void index(
+ constant size_t &dst_size,
+ constant size_t &left_size,
+ constant size_t &src_dim_size,
+ constant size_t &right_size,
constant size_t &ids_size,
constant bool &contiguous,
constant size_t *src_dims,
constant size_t *src_strides,
const device TYPENAME *input,
- const device INDEX_TYPENAME *input_ids,
- device TYPENAME *output,
- uint tid [[ thread_position_in_grid ]]
-) {
- if (tid >= dst_size) {
+ const device INDEX_TYPENAME *input_ids,
+ device TYPENAME *output,
+ uint tid [[ thread_position_in_grid ]]
+) {
+ if (tid >= dst_size) {
return;
- }
- const size_t id_i = (tid / right_size) % ids_size;
- const INDEX_TYPENAME input_i = min(input_ids[id_i], (INDEX_TYPENAME)(src_dim_size - 1));
- const size_t right_rank_i = tid % right_size;
- const size_t left_rank_i = tid / right_size / ids_size;
- /*
- // Force prevent out of bounds indexing
- // since there doesn't seem to be a good way to force crash
- // No need to check for zero we're only allowing unsized.
- */
- const size_t src_i = left_rank_i * src_dim_size * right_size + input_i * right_size + right_rank_i;
+ }
+ const size_t id_i = (tid / right_size) % ids_size;
+ const INDEX_TYPENAME input_i = min(input_ids[id_i], (INDEX_TYPENAME)(src_dim_size - 1));
+ const size_t right_rank_i = tid % right_size;
+ const size_t left_rank_i = tid / right_size / ids_size;
+ /*
+ // Force prevent out of bounds indexing
+ // since there doesn't seem to be a good way to force crash
+ // No need to check for zero we're only allowing unsized.
+ */
+ const size_t src_i = left_rank_i * src_dim_size * right_size + input_i * right_size + right_rank_i;
const size_t strided_src_i = contiguous ? src_i : get_strided_index(src_i, src_dim_size, src_dims, src_strides);
output[tid] = input[strided_src_i];
}
@@ -68,25 +68,25 @@ kernel void NAME( \
template<typename TYPENAME, typename INDEX_TYPENAME>
-METAL_FUNC void gather(
- constant size_t &dst_size,
- constant size_t &left_size,
- constant size_t &src_dim_size,
- constant size_t &right_size,
- constant size_t &ids_size,
- const device TYPENAME *input,
- const device INDEX_TYPENAME *input_ids,
- device TYPENAME *output,
- uint tid [[ thread_position_in_grid ]]
-) {
- if (tid >= dst_size) {
- return;
- }
- const INDEX_TYPENAME input_i = input_ids[tid];
- const size_t right_rank_i = tid % right_size;
- const size_t left_rank_i = tid / right_size / ids_size;
- const size_t src_i = (left_rank_i * src_dim_size + input_i) * right_size + right_rank_i;
- output[tid] = input[src_i];
+METAL_FUNC void gather(
+ constant size_t &dst_size,
+ constant size_t &left_size,
+ constant size_t &src_dim_size,
+ constant size_t &right_size,
+ constant size_t &ids_size,
+ const device TYPENAME *input,
+ const device INDEX_TYPENAME *input_ids,
+ device TYPENAME *output,
+ uint tid [[ thread_position_in_grid ]]
+) {
+ if (tid >= dst_size) {
+ return;
+ }
+ const INDEX_TYPENAME input_i = input_ids[tid];
+ const size_t right_rank_i = tid % right_size;
+ const size_t left_rank_i = tid / right_size / ids_size;
+ const size_t src_i = (left_rank_i * src_dim_size + input_i) * right_size + right_rank_i;
+ output[tid] = input[src_i];
}
# define GATHER_OP(NAME, INDEX_TYPENAME, TYPENAME) \
@@ -105,27 +105,27 @@ kernel void NAME( \
}
template<typename TYPENAME, typename INDEX_TYPENAME>
-METAL_FUNC void scatter_add(
- constant size_t &dst_size,
- constant size_t &left_size,
- constant size_t &src_dim_size,
- constant size_t &right_size,
- constant size_t &dst_dim_size,
- const device TYPENAME *input,
- const device INDEX_TYPENAME *input_ids,
- device TYPENAME *output,
- uint tid [[ thread_position_in_grid ]]
-) {
- if (tid >= dst_size) {
- return;
- }
- const size_t right_rank_i = tid % right_size;
- const size_t left_rank_i = tid / right_size;
+METAL_FUNC void scatter_add(
+ constant size_t &dst_size,
+ constant size_t &left_size,
+ constant size_t &src_dim_size,
+ constant size_t &right_size,
+ constant size_t &dst_dim_size,
+ const device TYPENAME *input,
+ const device INDEX_TYPENAME *input_ids,
+ device TYPENAME *output,
+ uint tid [[ thread_position_in_grid ]]
+) {
+ if (tid >= dst_size) {
+ return;
+ }
+ const size_t right_rank_i = tid % right_size;
+ const size_t left_rank_i = tid / right_size;
for (unsigned int j = 0; j < src_dim_size; ++j) {
- const size_t src_i = (left_rank_i * src_dim_size + j) * right_size + right_rank_i;
+ const size_t src_i = (left_rank_i * src_dim_size + j) * right_size + right_rank_i;
const INDEX_TYPENAME idx = input_ids[src_i];
- const size_t dst_i = (left_rank_i * dst_dim_size + idx) * right_size + right_rank_i;
- output[dst_i] += input[src_i];
+ const size_t dst_i = (left_rank_i * dst_dim_size + idx) * right_size + right_rank_i;
+ output[dst_i] += input[src_i];
}
}
@@ -145,28 +145,28 @@ kernel void NAME( \
}
template<typename TYPENAME, typename INDEX_TYPENAME>
-METAL_FUNC void index_add(
- constant size_t &dst_size,
- constant size_t &left_size,
- constant size_t &src_dim_size,
- constant size_t &right_size,
- constant size_t &dst_dim_size,
- constant size_t &ids_dim_size,
- const device TYPENAME *input,
- const device INDEX_TYPENAME *input_ids,
- device TYPENAME *output,
- uint tid [[ thread_position_in_grid ]]
-) {
- if (tid >= dst_size) {
- return;
- }
- const size_t right_rank_i = tid % right_size;
- const size_t left_rank_i = tid / right_size;
+METAL_FUNC void index_add(
+ constant size_t &dst_size,
+ constant size_t &left_size,
+ constant size_t &src_dim_size,
+ constant size_t &right_size,
+ constant size_t &dst_dim_size,
+ constant size_t &ids_dim_size,
+ const device TYPENAME *input,
+ const device INDEX_TYPENAME *input_ids,
+ device TYPENAME *output,
+ uint tid [[ thread_position_in_grid ]]
+) {
+ if (tid >= dst_size) {
+ return;
+ }
+ const size_t right_rank_i = tid % right_size;
+ const size_t left_rank_i = tid / right_size;
for (unsigned int j = 0; j < ids_dim_size; ++j) {
const INDEX_TYPENAME idx = input_ids[j];
- const size_t src_i = (left_rank_i * src_dim_size + j) * right_size + right_rank_i;
- const size_t dst_i = (left_rank_i * dst_dim_size + idx) * right_size + right_rank_i;
- output[dst_i] += input[src_i];
+ const size_t src_i = (left_rank_i * src_dim_size + j) * right_size + right_rank_i;
+ const size_t dst_i = (left_rank_i * dst_dim_size + idx) * right_size + right_rank_i;
+ output[dst_i] += input[src_i];
}
}
@@ -214,6 +214,7 @@ GATHER_OP(gather_u32_f16, uint, half)
#if defined(__HAVE_BFLOAT__)
GATHER_OP(gather_u32_bf16, uint, bfloat)
#endif
+GATHER_OP(gather_u32_u32, uint, uint)
SCATTER_ADD_OP(sa_u32_f32, uint32_t, float)
SCATTER_ADD_OP(sa_u8_f32, uint8_t, float)