summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: zachcp <zachcp@users.noreply.github.com> 2024-12-01 12:39:38 -0500
committer: GitHub <noreply@github.com> 2024-12-01 18:39:38 +0100
commit: 6f715f92564c10426c5565cd30ece25aee8d72ac (patch)
tree: 40d4ab40c7d5f52c41d8d4b6e8b73a758947a0bb
parent: dba7a9c93e4c84c8197e8a5b56f40adcf2650bde (diff)
download: candle-6f715f92564c10426c5565cd30ece25aee8d72ac.tar.gz
candle-6f715f92564c10426c5565cd30ece25aee8d72ac.tar.bz2
candle-6f715f92564c10426c5565cd30ece25aee8d72ac.zip
add scatter add (#2656)
-rw-r--r-- candle-core/src/metal_backend/mod.rs 1
-rw-r--r-- candle-metal-kernels/src/indexing.metal 1
2 files changed, 2 insertions, 0 deletions
diff --git a/candle-core/src/metal_backend/mod.rs b/candle-core/src/metal_backend/mod.rs
index e8159f46..bffba50d 100644
--- a/candle-core/src/metal_backend/mod.rs
+++ b/candle-core/src/metal_backend/mod.rs
@@ -1284,6 +1284,7 @@ impl BackendStorage for MetalStorage {
(DType::U8, DType::F32) => "sa_u8_f32",
(DType::U8, DType::F16) => "sa_u8_f16",
(DType::U8, DType::BF16) => "sa_u8_bf16",
+ (DType::U32, DType::U32) => "sa_u32_u32",
(DType::U32, DType::F32) => "sa_u32_f32",
(DType::U32, DType::F16) => "sa_u32_f16",
(DType::U32, DType::BF16) => "sa_u32_bf16",
diff --git a/candle-metal-kernels/src/indexing.metal b/candle-metal-kernels/src/indexing.metal
index 2594689c..7509b628 100644
--- a/candle-metal-kernels/src/indexing.metal
+++ b/candle-metal-kernels/src/indexing.metal
@@ -219,6 +219,7 @@ GATHER_OP(gather_u32_u32, uint, uint)
SCATTER_ADD_OP(sa_u32_f32, uint32_t, float)
SCATTER_ADD_OP(sa_u8_f32, uint8_t, float)
SCATTER_ADD_OP(sa_i64_f32, int64_t, float)
+SCATTER_ADD_OP(sa_u32_u32, uint32_t, uint32_t)
SCATTER_ADD_OP(sa_u32_f16, uint32_t, half)
SCATTER_ADD_OP(sa_u8_f16, uint8_t, half)
SCATTER_ADD_OP(sa_i64_f16, int64_t, half)