summaryrefslogtreecommitdiff
path: root/candle-metal-kernels/src/cast.metal
diff options
context:
space:
mode:
Diffstat (limited to 'candle-metal-kernels/src/cast.metal')
-rw-r--r--candle-metal-kernels/src/cast.metal42
1 files changed, 38 insertions, 4 deletions
diff --git a/candle-metal-kernels/src/cast.metal b/candle-metal-kernels/src/cast.metal
index e9ab17b1..9aead139 100644
--- a/candle-metal-kernels/src/cast.metal
+++ b/candle-metal-kernels/src/cast.metal
@@ -28,7 +28,7 @@ kernel void FN_NAME( \
if (tid >= dim) { \
return; \
} \
- output[tid] = RIGHT_TYPENAME(input[tid]); \
+ output[tid] = static_cast<RIGHT_TYPENAME>(input[tid]); \
} \
kernel void FN_NAME_STRIDED( \
constant size_t &dim, \
@@ -42,7 +42,34 @@ kernel void FN_NAME_STRIDED( \
if (tid >= dim) { \
return; \
} \
- output[tid] = RIGHT_TYPENAME(input[get_strided_index(tid, num_dims, dims, strides)]); \
+ output[tid] = static_cast<RIGHT_TYPENAME>(input[get_strided_index(tid, num_dims, dims, strides)]); \
+} \
+
+#define CAST_THROUGH(FN_NAME, FN_NAME_STRIDED, LEFT_TYPENAME, RIGHT_TYPENAME, IR_TYPENAME) \
+kernel void FN_NAME( \
+ constant size_t &dim, \
+ device const LEFT_TYPENAME *input, \
+ device RIGHT_TYPENAME *output, \
+ uint tid [[ thread_position_in_grid ]] \
+) { \
+ if (tid >= dim) { \
+ return; \
+ } \
+ output[tid] = static_cast<RIGHT_TYPENAME>(static_cast<IR_TYPENAME>(input[tid])); \
+} \
+kernel void FN_NAME_STRIDED( \
+ constant size_t &dim, \
+ constant size_t &num_dims, \
+ constant size_t *dims, \
+ constant size_t *strides, \
+ device const LEFT_TYPENAME *input, \
+ device RIGHT_TYPENAME *output, \
+ uint tid [[ thread_position_in_grid ]] \
+) { \
+ if (tid >= dim) { \
+ return; \
+ } \
+ output[tid] = static_cast<RIGHT_TYPENAME>(static_cast<IR_TYPENAME>(input[get_strided_index(tid, num_dims, dims, strides)])); \
} \
CAST(cast_u32_f32, cast_u32_f32_strided, uint32_t, float)
@@ -58,7 +85,14 @@ CAST(cast_u32_i64, cast_u32_i64_strided, uint32_t, int64_t)
CAST(cast_i64_f32, cast_i64_f32_strided, int64_t, float)
#endif
-#if __METAL_VERSION__ >= 310
+#if defined(__HAVE_BFLOAT__)
+CAST(cast_bf16_u32, cast_bf16_u32_strided, bfloat, uint32_t)
CAST(cast_bf16_f32, cast_bf16_f32_strided, bfloat, float)
+CAST(cast_u8_bf16, cast_u8_bf16_strided, uint8_t, bfloat)
+CAST(cast_u32_bf16, cast_u32_bf16_strided, uint32_t, bfloat)
CAST(cast_f32_bf16, cast_f32_bf16_strided, float, bfloat)
-#endif
+
+CAST_THROUGH(cast_bf16_u8, cast_bf16_u8_strided, bfloat, uint8_t, float)
+CAST_THROUGH(cast_bf16_f16, cast_bf16_f16_strided, bfloat, half, float)
+CAST_THROUGH(cast_f16_bf16, cast_f16_bf16_strided, half, bfloat, float)
+#endif \ No newline at end of file