1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
|
// Imported from https://github.com/ggerganov/llama.cpp/blob/master/ggml-metal.metal
#include <metal_stdlib>
using namespace metal;
// Exchange the values of two lvalues through a temporary.
// Each argument is expanded twice, so pass side-effect-free expressions only.
#define SWAP(x, y) { auto tmp = (x); (x) = (y); (y) = tmp; }
// Values for the `order` template parameter of argsort (ascending / descending).
#define SORT_ASC 1
#define SORT_DESC 0
// Bitonic argsort of a single row, executed cooperatively by one threadgroup.
//
// Template parameters:
//   order - SORT_ASC or SORT_DESC.
//   T     - element type of the input values.
//
// Arguments:
//   x             - input values; the row handled here starts at x + row * ncols.
//   dst           - output indices; the first ncols sorted indices of the row
//                   are written to dst + row * ncols.
//   ncols         - number of real elements per row.
//   ncols_pad     - number of slots the sorting network operates on. The
//                   `col ^ j` partner computation assumes this is a power of
//                   two that is >= ncols (NOTE(review): confirm at the call site).
//   shared_values - threadgroup scratch holding one uint32_t index per slot,
//                   i.e. at least ncols_pad entries.
//   tgpig         - threadgroup position; component 1 selects the row.
//   tpitg         - thread position; component 0 selects the slot (`col`).
//
// Slots whose index is >= ncols are padding. The comparisons below treat a
// padding index as larger than any real element for both sort orders, so the
// padding ends up behind the real entries and is dropped by the final copy.
template<int order, typename T>
METAL_FUNC void argsort(
        device const T * x,
        device uint32_t * dst,
        constant int64_t & ncols,
        constant int64_t & ncols_pad,
        threadgroup uint32_t * shared_values [[threadgroup(0)]],
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]]) {
    int col = tpitg[0];
    int row = tgpig[1];

    // Threads past the padded width do no work, but they must NOT return
    // early: every thread of the threadgroup has to reach every
    // threadgroup_barrier below, otherwise the barrier behaviour is undefined.
    const bool active = col < ncols_pad;

    device const T * x_row = x + row * ncols;
    threadgroup uint32_t * dst_row = shared_values;

    // initialize indices
    if (active) {
        dst_row[col] = col;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // The loop bounds depend only on ncols_pad, so all threads execute the
    // same number of iterations and therefore the same number of barriers.
    for (int k = 2; k <= ncols_pad; k *= 2) {
        for (int j = k / 2; j > 0; j /= 2) {
            int ixj = col ^ j;
            // Only the lower slot of each (col, ixj) pair performs the
            // compare-exchange, so no two threads touch the same pair.
            if (active && ixj > col) {
                const uint32_t lhs = dst_row[col];
                const uint32_t rhs = dst_row[ixj];
                bool do_swap;
                if ((col & k) == 0) {
                    // Forward-direction block: padding in the lower slot always
                    // moves up; real values are ordered according to `order`.
                    do_swap = lhs >= ncols ||
                        (rhs < ncols && (order == SORT_ASC ?
                            x_row[lhs] > x_row[rhs] :
                            x_row[lhs] < x_row[rhs]));
                } else {
                    // Reverse-direction block: mirror image of the case above.
                    do_swap = rhs >= ncols ||
                        (lhs < ncols && (order == SORT_ASC ?
                            x_row[lhs] < x_row[rhs] :
                            x_row[lhs] > x_row[rhs]));
                }
                if (do_swap) {
                    dst_row[col] = rhs;
                    dst_row[ixj] = lhs;
                }
            }
            threadgroup_barrier(mem_flags::mem_threadgroup);
        }
    }

    // copy the result to dst without the padding
    if (active && col < ncols) {
        dst[row * ncols + col] = dst_row[col];
    }
}
// Defines the two kernel entry points for element type T:
//   asort_asc_<RUST_T>  - forwards to argsort<SORT_ASC, T>
//   asort_desc_<RUST_T> - forwards to argsort<SORT_DESC, T>
// Both kernels pass all of their arguments through unchanged.
// The last line of the macro must not end in a backslash; otherwise the
// following source line is spliced into the macro body.
#define ARGSORT(T, RUST_T) \
kernel void asort_asc_##RUST_T( \
    device const T * x, \
    device uint32_t * dst, \
    constant int64_t & ncols, \
    constant int64_t & ncols_pad, \
    threadgroup uint32_t * shared_values [[threadgroup(0)]], \
    uint3 tgpig[[threadgroup_position_in_grid]], \
    uint3 tpitg[[thread_position_in_threadgroup]] \
) { \
    argsort<SORT_ASC, T>(x, dst, ncols, ncols_pad, shared_values, tgpig, tpitg); \
} \
kernel void asort_desc_##RUST_T( \
    device const T * x, \
    device uint32_t * dst, \
    constant int64_t & ncols, \
    constant int64_t & ncols_pad, \
    threadgroup uint32_t * shared_values [[threadgroup(0)]], \
    uint3 tgpig[[threadgroup_position_in_grid]], \
    uint3 tpitg[[thread_position_in_threadgroup]] \
) { \
    argsort<SORT_DESC, T>(x, dst, ncols, ncols_pad, shared_values, tgpig, tpitg); \
}

ARGSORT(float, f32)
// Instantiate the asort_asc_* / asort_desc_* kernels for further element types.
ARGSORT(half, f16)
ARGSORT(uint8_t, u8)
ARGSORT(uint32_t, u32)
// The int64_t kernels are only compiled when __METAL_VERSION__ >= 220.
#if __METAL_VERSION__ >= 220
ARGSORT(int64_t, i64)
#endif
// The bfloat kernels are only compiled when __HAVE_BFLOAT__ is defined.
#if defined(__HAVE_BFLOAT__)
ARGSORT(bfloat, bf16)
#endif
|