summary | refs | log | tree | commit | diff
path: root/candle-metal-kernels/src/fill.metal
diff options
context:
space:
mode:
Diffstat (limited to 'candle-metal-kernels/src/fill.metal')
-rw-r--r-- candle-metal-kernels/src/fill.metal 39
1 files changed, 39 insertions, 0 deletions
diff --git a/candle-metal-kernels/src/fill.metal b/candle-metal-kernels/src/fill.metal
new file mode 100644
index 00000000..35c3fe7a
--- /dev/null
+++ b/candle-metal-kernels/src/fill.metal
@@ -0,0 +1,39 @@
+#include <metal_stdlib>
+
+using namespace metal;
+
+// Per-thread body used by the fill kernels: stores `value`, converted to T,
+// at out[tid]. A thread whose index is at or past `numel` stores nothing.
+template<typename T> METAL_FUNC void fill_with(
+    device T *out, constant float &value, constant size_t &numel,
+    uint tid [[thread_position_in_grid]]
+) {
+    const bool in_range = tid < numel;
+    if (in_range) {
+        out[tid] = static_cast<T>(value);
+    }
+}
+
+// Defines the kernel entry point fill_<NAME> writing elements of type T.
+// The kernel only forwards its arguments to fill_with<T>.
+#define FILL_OP(NAME, T) \
+kernel void fill_##NAME( \
+    device T *out, constant float &value, constant size_t &numel, \
+    uint tid [[thread_position_in_grid]] \
+) { \
+    fill_with<T>(out, value, numel, tid); \
+}
+
+
+// Expands to every fill kernel for one (NAME, T) pair; currently just FILL_OP.
+#define FILL_OPS(NAME, T) FILL_OP(NAME, T)
+
+FILL_OPS(u8, uchar)    // defines kernel fill_u8
+FILL_OPS(u32, uint)    // defines kernel fill_u32
+FILL_OPS(i64, long)    // defines kernel fill_i64
+FILL_OPS(f16, half)    // defines kernel fill_f16
+FILL_OPS(f32, float)   // defines kernel fill_f32
+
+#if __METAL_VERSION__ >= 310  // the bfloat variant is compiled only when this version check passes
+FILL_OPS(bf16, bfloat) // defines kernel fill_bf16
+#endif