summaryrefslogtreecommitdiff
path: root/candle-metal-kernels
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2024-05-22 09:44:52 +0200
committerGitHub <noreply@github.com>2024-05-22 09:44:52 +0200
commit72e7ca529a3c243bef844f822a9668eaf8e36807 (patch)
tree5353f8bde979d9b6039deee3c10c2ada41135d36 /candle-metal-kernels
parent7ff921c5385e1f08dc534b67a969cd06b91714d5 (diff)
downloadcandle-72e7ca529a3c243bef844f822a9668eaf8e36807.tar.gz
candle-72e7ca529a3c243bef844f822a9668eaf8e36807.tar.bz2
candle-72e7ca529a3c243bef844f822a9668eaf8e36807.zip
Add some missing where-cond kernels for metal. (#2203)
Diffstat (limited to 'candle-metal-kernels')
-rw-r--r--candle-metal-kernels/src/ternary.metal31
1 files changed, 17 insertions, 14 deletions
diff --git a/candle-metal-kernels/src/ternary.metal b/candle-metal-kernels/src/ternary.metal
index 7b3b8ca9..fe04f237 100644
--- a/candle-metal-kernels/src/ternary.metal
+++ b/candle-metal-kernels/src/ternary.metal
@@ -1,5 +1,4 @@
#include <metal_stdlib>
-#
using namespace metal;
METAL_FUNC uint get_strided_index(
@@ -57,27 +56,31 @@ kernel void FN_NAME(
where_cond<T, ID>(numel, num_dims, dims, strides, strides_t, strides_f, ids, t, f, out, i); \
} \
-// WHERE_OP(float, int64_t, where_i64_f32)
-// WHERE_OP(double, int64_t, where_i64_f64)
-// WHERE_OP(uint8_t, int64_t, where_i64_u8)
-// WHERE_OP(uint32_t, int64_t, where_i64_u32)
-// WHERE_OP(int64_t, int64_t, where_i64_i64)
-//
-// WHERE_OP(float, uint32_t, where_u32_f32)
-// WHERE_OP(double, uint32_t, where_u32_f64)
-// WHERE_OP(uint8_t, uint32_t, where_u32_u8)
-// WHERE_OP(uint32_t, uint32_t, where_u32_u32)
-// WHERE_OP(int64_t, uint32_t, where_u32_i64)
+WHERE_OP(half, uint32_t, where_u32_f16)
+WHERE_OP(float, uint32_t, where_u32_f32)
+WHERE_OP(uint8_t, uint32_t, where_u32_u8)
+WHERE_OP(uint32_t, uint32_t, where_u32_u32)
-WHERE_OP(float, uint8_t, where_u8_f32)
WHERE_OP(half, uint8_t, where_u8_f16)
+WHERE_OP(float, uint8_t, where_u8_f32)
WHERE_OP(uint8_t, uint8_t, where_u8_u8)
WHERE_OP(uint32_t, uint8_t, where_u8_u32)
#if __METAL_VERSION__ >= 220
WHERE_OP(int64_t, uint8_t, where_u8_i64)
+WHERE_OP(int64_t, uint32_t, where_u32_i64)
+
+WHERE_OP(half, int64_t, where_i64_f16)
+WHERE_OP(float, int64_t, where_i64_f32)
+WHERE_OP(uint8_t, int64_t, where_i64_u8)
+WHERE_OP(uint32_t, int64_t, where_i64_u32)
+WHERE_OP(int64_t, int64_t, where_i64_i64)
+#if defined(__HAVE_BFLOAT__)
+WHERE_OP(bfloat, int64_t, where_i64_bf16)
+#endif
#endif
#if defined(__HAVE_BFLOAT__)
WHERE_OP(bfloat, uint8_t, where_u8_bf16)
-#endif \ No newline at end of file
+WHERE_OP(bfloat, uint32_t, where_u32_bf16)
+#endif