summaryrefslogtreecommitdiff
path: root/candle-metal-kernels/src/fill.metal
blob: 35c3fe7ab2d2bd511327391a6ebfdd29ee7a1254 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
#include <metal_stdlib>

using namespace metal;

// Per-thread fill step shared by every fill_* kernel.
//
// Thread `tid` stores `value`, converted to `T` via static_cast, into
// `out[tid]`. Threads with `tid >= numel` write nothing, so the dispatched
// grid may be larger than the element count.
//
// NOTE(review): `value` is always a 32-bit float regardless of `T`, so fills of
// wide integer types (e.g. long/uint) cannot exactly represent values beyond
// float's 24-bit mantissa — confirm callers never rely on that.
template<typename T> METAL_FUNC void fill_with(
    device T *out,
    constant float &value,
    constant size_t &numel,
    uint tid [[thread_position_in_grid]]
) {
    const bool inside = tid < numel;
    if (inside) {
        const T converted = static_cast<T>(value);
        out[tid] = converted;
    }
}

// FILL_OP(NAME, T) defines a kernel named `fill_<NAME>` that fills a
// `device T` buffer. Each thread forwards its grid position to fill_with<T>.
// The argument order (out, value, numel) presumably matches how the host binds
// buffers by index — keep it stable (TODO confirm against the host launcher).
#define FILL_OP(SUFFIX, ELEM)                             \
kernel void fill_##SUFFIX(                                \
    device ELEM *out,                                     \
    constant float &value,                                \
    constant size_t &numel,                               \
    uint tid [[thread_position_in_grid]]                  \
) {                                                       \
    fill_with<ELEM>(out, value, numel, tid);              \
}


// FILL_OPS(NAME, T) is the per-dtype entry point used by the instantiation
// list. It currently expands to exactly one FILL_OP; presumably it is kept as
// a hook for adding further per-dtype fill variants (TODO confirm).
#define FILL_OPS(DTYPE, METAL_T) FILL_OP(DTYPE, METAL_T)

// Integer element types: fill_u8, fill_u32, fill_i64.
FILL_OPS(u8, uchar)
FILL_OPS(u32, uint)
FILL_OPS(i64, long)

// Floating-point element types: fill_f16, fill_f32.
FILL_OPS(f16, half)
FILL_OPS(f32, float)

// fill_bf16 is only compiled when the Metal language version is at least
// 3.1 (the guard below), since `bfloat` is not available before that.
#if __METAL_VERSION__ >= 310
FILL_OPS(bf16, bfloat)
#endif