diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-05-22 09:44:52 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-05-22 09:44:52 +0200 |
commit | 72e7ca529a3c243bef844f822a9668eaf8e36807 (patch) | |
tree | 5353f8bde979d9b6039deee3c10c2ada41135d36 /candle-metal-kernels | |
parent | 7ff921c5385e1f08dc534b67a969cd06b91714d5 (diff) | |
download | candle-72e7ca529a3c243bef844f822a9668eaf8e36807.tar.gz candle-72e7ca529a3c243bef844f822a9668eaf8e36807.tar.bz2 candle-72e7ca529a3c243bef844f822a9668eaf8e36807.zip |
Add some missing where-cond kernels for metal. (#2203)
Diffstat (limited to 'candle-metal-kernels')
-rw-r--r-- | candle-metal-kernels/src/ternary.metal | 31 |
1 files changed, 17 insertions, 14 deletions
diff --git a/candle-metal-kernels/src/ternary.metal b/candle-metal-kernels/src/ternary.metal index 7b3b8ca9..fe04f237 100644 --- a/candle-metal-kernels/src/ternary.metal +++ b/candle-metal-kernels/src/ternary.metal @@ -1,5 +1,4 @@ #include <metal_stdlib> -# using namespace metal; METAL_FUNC uint get_strided_index( @@ -57,27 +56,31 @@ kernel void FN_NAME( where_cond<T, ID>(numel, num_dims, dims, strides, strides_t, strides_f, ids, t, f, out, i); \ } \ -// WHERE_OP(float, int64_t, where_i64_f32) -// WHERE_OP(double, int64_t, where_i64_f64) -// WHERE_OP(uint8_t, int64_t, where_i64_u8) -// WHERE_OP(uint32_t, int64_t, where_i64_u32) -// WHERE_OP(int64_t, int64_t, where_i64_i64) -// -// WHERE_OP(float, uint32_t, where_u32_f32) -// WHERE_OP(double, uint32_t, where_u32_f64) -// WHERE_OP(uint8_t, uint32_t, where_u32_u8) -// WHERE_OP(uint32_t, uint32_t, where_u32_u32) -// WHERE_OP(int64_t, uint32_t, where_u32_i64) +WHERE_OP(half, uint32_t, where_u32_f16) +WHERE_OP(float, uint32_t, where_u32_f32) +WHERE_OP(uint8_t, uint32_t, where_u32_u8) +WHERE_OP(uint32_t, uint32_t, where_u32_u32) -WHERE_OP(float, uint8_t, where_u8_f32) WHERE_OP(half, uint8_t, where_u8_f16) +WHERE_OP(float, uint8_t, where_u8_f32) WHERE_OP(uint8_t, uint8_t, where_u8_u8) WHERE_OP(uint32_t, uint8_t, where_u8_u32) #if __METAL_VERSION__ >= 220 WHERE_OP(int64_t, uint8_t, where_u8_i64) +WHERE_OP(int64_t, uint32_t, where_u32_i64) + +WHERE_OP(half, int64_t, where_i64_f16) +WHERE_OP(float, int64_t, where_i64_f32) +WHERE_OP(uint8_t, int64_t, where_i64_u8) +WHERE_OP(uint32_t, int64_t, where_i64_u32) +WHERE_OP(int64_t, int64_t, where_i64_i64) +#if defined(__HAVE_BFLOAT__) +WHERE_OP(bfloat, int64_t, where_i64_bf16) +#endif #endif #if defined(__HAVE_BFLOAT__) WHERE_OP(bfloat, uint8_t, where_u8_bf16) -#endif
\ No newline at end of file +WHERE_OP(bfloat, uint32_t, where_u32_bf16) +#endif |