summaryrefslogtreecommitdiff
path: root/candle-metal-kernels/src/cast.metal
diff options
context:
space:
mode:
Diffstat (limited to 'candle-metal-kernels/src/cast.metal')
-rw-r--r--candle-metal-kernels/src/cast.metal18
1 files changed, 11 insertions, 7 deletions
diff --git a/candle-metal-kernels/src/cast.metal b/candle-metal-kernels/src/cast.metal
index d1788253..4398e9d4 100644
--- a/candle-metal-kernels/src/cast.metal
+++ b/candle-metal-kernels/src/cast.metal
@@ -23,12 +23,12 @@ kernel void FN_NAME( \
constant size_t &dim, \
device const LEFT_TYPENAME *input, \
device RIGHT_TYPENAME *output, \
- uint thread_position_in_grid [[ thread_position_in_grid ]] \
+ uint tid [[ thread_position_in_grid ]] \
) { \
- if (thread_position_in_grid >= dim) { \
+ if (tid >= dim) { \
return; \
} \
- output[thread_position_in_grid] = RIGHT_TYPENAME(input[thread_position_in_grid]); \
+ output[tid] = RIGHT_TYPENAME(input[tid]); \
} \
kernel void FN_NAME_STRIDED( \
constant size_t &dim, \
@@ -37,15 +37,19 @@ kernel void FN_NAME_STRIDED( \
constant size_t *strides, \
device const LEFT_TYPENAME *input, \
device RIGHT_TYPENAME *output, \
- uint i [[ thread_position_in_grid ]] \
+ uint tid [[ thread_position_in_grid ]] \
) { \
- if (i >= dim) { \
+ if (tid >= dim) { \
return; \
} \
- output[i] = RIGHT_TYPENAME(input[get_strided_index(i, num_dims, dims, strides)]); \
+ output[tid] = RIGHT_TYPENAME(input[get_strided_index(tid, num_dims, dims, strides)]); \
} \
-CAST(cast_u32_f32, cast_u32_f32_strided, int32_t, float)
+CAST(cast_u32_f32, cast_u32_f32_strided, uint32_t, float)
+CAST(cast_u32_u8, cast_u32_u8_strided, uint32_t, uint8_t)
+CAST(cast_u8_u32, cast_u8_u32_strided, uint8_t, uint32_t)
+CAST(cast_f16_f32, cast_f16_f32_strided, half, float)
+CAST(cast_f32_f16, cast_f32_f16_strided, float, half)
#if __METAL_VERSION__ >= 310
#endif