summaryrefslogtreecommitdiff
path: root/candle-metal-kernels/src/sort.metal
diff options
context:
space:
mode:
Diffstat (limited to 'candle-metal-kernels/src/sort.metal')
-rw-r--r--candle-metal-kernels/src/sort.metal97
1 files changed, 97 insertions, 0 deletions
diff --git a/candle-metal-kernels/src/sort.metal b/candle-metal-kernels/src/sort.metal
new file mode 100644
index 00000000..d71ab822
--- /dev/null
+++ b/candle-metal-kernels/src/sort.metal
@@ -0,0 +1,97 @@
+// Imported from https://github.com/ggerganov/llama.cpp/blob/master/ggml-metal.metal
+#include <metal_stdlib>
+using namespace metal;
+
+#define SWAP(x, y) { auto tmp = (x); (x) = (y); (y) = tmp; }
+#define SORT_ASC 1
+#define SORT_DESC 0
+
+template<int order, typename T>
+METAL_FUNC void argsort(
+ device const T * x,
+ device uint32_t * dst,
+ constant int64_t & ncols,
+ constant int64_t & ncols_pad,
+ threadgroup uint32_t * shared_values [[threadgroup(0)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]]) {
+ int col = tpitg[0];
+ int row = tgpig[1];
+
+ if (col >= ncols_pad) return;
+
+ device const T * x_row = x + row * ncols;
+ threadgroup uint32_t * dst_row = shared_values;
+
+ // initialize indices
+ dst_row[col] = col;
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ for (int k = 2; k <= ncols_pad; k *= 2) {
+ for (int j = k / 2; j > 0; j /= 2) {
+ int ixj = col ^ j;
+ if (ixj > col) {
+ if ((col & k) == 0) {
+ if (dst_row[col] >= ncols ||
+ (dst_row[ixj] < ncols && (order == SORT_ASC ?
+ x_row[dst_row[col]] > x_row[dst_row[ixj]] :
+ x_row[dst_row[col]] < x_row[dst_row[ixj]]))
+ ) {
+ SWAP(dst_row[col], dst_row[ixj]);
+ }
+ } else {
+ if (dst_row[ixj] >= ncols ||
+ (dst_row[col] < ncols && (order == SORT_ASC ?
+ x_row[dst_row[col]] < x_row[dst_row[ixj]] :
+ x_row[dst_row[col]] > x_row[dst_row[ixj]]))
+ ) {
+ SWAP(dst_row[col], dst_row[ixj]);
+ }
+ }
+ }
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ }
+ }
+
+ // copy the result to dst without the padding
+ if (col < ncols) {
+ dst[row * ncols + col] = dst_row[col];
+ }
+}
+
+#define ARGSORT(T, RUST_T) \
+kernel void asort_asc_##RUST_T( \
+ device const T * x, \
+ device uint32_t * dst, \
+ constant int64_t & ncols, \
+ constant int64_t & ncols_pad, \
+ threadgroup uint32_t * shared_values [[threadgroup(0)]], \
+ uint3 tgpig[[threadgroup_position_in_grid]], \
+ uint3 tpitg[[thread_position_in_threadgroup]] \
+) { \
+ argsort<SORT_ASC, T>(x, dst, ncols, ncols_pad, shared_values, tgpig, tpitg); \
+} \
+kernel void asort_desc_##RUST_T( \
+ device const T * x, \
+ device uint32_t * dst, \
+ constant int64_t & ncols, \
+ constant int64_t & ncols_pad, \
+ threadgroup uint32_t * shared_values [[threadgroup(0)]], \
+ uint3 tgpig[[threadgroup_position_in_grid]], \
+ uint3 tpitg[[thread_position_in_threadgroup]] \
+) { \
+ argsort<SORT_DESC, T>(x, dst, ncols, ncols_pad, shared_values, tgpig, tpitg); \
+} \
+
+ARGSORT(float, f32)
+ARGSORT(half, f16)
+ARGSORT(uint8_t, u8)
+ARGSORT(uint32_t, u32)
+
+#if __METAL_VERSION__ >= 220
+ARGSORT(int64_t, i64)
+#endif
+#if defined(__HAVE_BFLOAT__)
+ARGSORT(bfloat, bf16)
+#endif