summaryrefslogtreecommitdiff
path: root/candle-metal-kernels/src/sort.metal
blob: d71ab822342b27b98dc5984bfedf51236ccbdad8 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
// Imported from https://github.com/ggerganov/llama.cpp/blob/master/ggml-metal.metal
#include <metal_stdlib>
using namespace metal;

// Exchange the values of two lvalues of the same type.
// Wrapped in do { } while (0) so that `SWAP(a, b);` behaves as a single
// statement (safe in an unbraced if/else), and the temporary uses a name that
// will not collide with a caller's variable called `tmp`.
#define SWAP(x, y) do { auto swap_tmp_ = (x); (x) = (y); (y) = swap_tmp_; } while (0)
// Values for the `order` template parameter of argsort.
#define SORT_ASC 1
#define SORT_DESC 0

// Bitonic argsort of a single row of `x`, performed cooperatively by the
// threads of one threadgroup.
//
// Template parameters:
//   order - SORT_ASC or SORT_DESC; selects the comparison direction.
//   T     - element type of the input values.
//
// Parameters:
//   x             - input values; this threadgroup's row starts at x + row * ncols.
//   dst           - output indices; ncols entries are written for the row.
//   ncols         - number of real elements per row.
//   ncols_pad     - number of slots that take part in the sort, padding included.
//                   The k/j loops double/halve starting from 2, so this is
//                   presumably a power of two >= ncols (verify against the
//                   host-side launch code).
//   shared_values - threadgroup scratch holding the index permutation; it is
//                   indexed by col for every col < ncols_pad, so it must hold at
//                   least ncols_pad uint32_t entries.
//   tgpig         - threadgroup position in the grid; component 1 selects the row.
//   tpitg         - thread position in the threadgroup; component 0 selects the slot.
//
// Indices >= ncols denote padding slots; they always compare as "after" any
// real element so they end up at the tail and are dropped on write-out.
template<int order, typename T>
METAL_FUNC void argsort(
        device const T        * x,
        device       uint32_t * dst,
        constant     int64_t & ncols,
        constant     int64_t & ncols_pad,
        threadgroup uint32_t  * shared_values [[threadgroup(0)]],
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]]) {
    int col = tpitg[0];
    int row = tgpig[1];

    // Threads past the padded width own no slot. They must not touch shared
    // memory, but they still have to execute every threadgroup_barrier below,
    // so they are masked out instead of returning early.
    const bool active = col < ncols_pad;

    device const T        * x_row   = x + row * ncols;
    threadgroup uint32_t  * dst_row = shared_values;

    // initialize indices
    if (active) {
        dst_row[col] = col;
    }

    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Loop bounds depend only on ncols_pad, so every thread in the threadgroup
    // runs the same number of iterations and hits the same barriers.
    for (int k = 2; k <= ncols_pad; k *= 2) {
        for (int j = k / 2; j > 0; j /= 2) {
            int ixj = col ^ j;
            // Only the lower-numbered thread of each (col, ixj) pair does the
            // compare-and-swap, so each pair is handled exactly once.
            if (active && ixj > col) {
                // Bit k of col selects the direction of this sub-sequence.
                // Swapping the roles of the two slots mirrors the comparison,
                // which lets both directions share one condition.
                const bool forward = (col & k) == 0;
                const int first  = forward ? col : ixj;
                const int second = forward ? ixj : col;

                const uint32_t idx_first  = dst_row[first];
                const uint32_t idx_second = dst_row[second];

                // Swap when `first` holds padding, or when both hold real
                // elements that are out of order for the requested direction.
                // Short-circuiting keeps x_row from being read at a padding index.
                const bool out_of_order =
                    idx_first >= ncols ||
                    (idx_second < ncols && (order == SORT_ASC ?
                        x_row[idx_first] > x_row[idx_second] :
                        x_row[idx_first] < x_row[idx_second]));

                if (out_of_order) {
                    dst_row[first]  = idx_second;
                    dst_row[second] = idx_first;
                }
            }
            threadgroup_barrier(mem_flags::mem_threadgroup);
        }
    }

    // copy the result to dst without the padding
    if (active && col < ncols) {
        dst[row * ncols + col] = dst_row[col];
    }
}

// Defines one argsort kernel entry point named NAME that forwards to
// argsort<ORDER, T>. ORDER is SORT_ASC or SORT_DESC.
#define ARGSORT_KERNEL(NAME, ORDER, T) \
kernel void NAME( \
    device const T        * x, \
    device       uint32_t * dst, \
    constant     int64_t & ncols, \
    constant     int64_t & ncols_pad, \
    threadgroup uint32_t  * shared_values [[threadgroup(0)]], \
    uint3 tgpig[[threadgroup_position_in_grid]], \
    uint3 tpitg[[thread_position_in_threadgroup]] \
) { \
    argsort<ORDER, T>(x, dst, ncols, ncols_pad, shared_values, tgpig, tpitg); \
}

// Instantiates the ascending and descending argsort kernels for element type T.
// The kernels are named asort_asc_<RUST_T> and asort_desc_<RUST_T>.
#define ARGSORT(T, RUST_T) \
ARGSORT_KERNEL(asort_asc_##RUST_T, SORT_ASC, T) \
ARGSORT_KERNEL(asort_desc_##RUST_T, SORT_DESC, T)

// Kernel instantiations: each line emits asort_asc_<suffix> and asort_desc_<suffix>.
ARGSORT(float, f32)
ARGSORT(half, f16)
ARGSORT(uint8_t, u8)
ARGSORT(uint32_t, u32)

// The int64_t kernels are only compiled when __METAL_VERSION__ is at least 220.
#if __METAL_VERSION__ >= 220
ARGSORT(int64_t, i64)
#endif
// The bfloat kernels are only compiled when __HAVE_BFLOAT__ is defined.
#if defined(__HAVE_BFLOAT__)
ARGSORT(bfloat, bf16)
#endif