summaryrefslogtreecommitdiff
path: root/candle-metal-kernels
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2024-04-27 20:17:35 +0200
committerGitHub <noreply@github.com>2024-04-27 20:17:35 +0200
commit96a48e5cc42b3c94d9d9687bb29987953df36db8 (patch)
tree4f1f391e6e6a8c1b865c4ab40e67aaf84dd21499 /candle-metal-kernels
parent6cf82fd7a34641601264ad1e0256ecadb7222474 (diff)
downloadcandle-96a48e5cc42b3c94d9d9687bb29987953df36db8.tar.gz
candle-96a48e5cc42b3c94d9d9687bb29987953df36db8.tar.bz2
candle-96a48e5cc42b3c94d9d9687bb29987953df36db8.zip
Add argsort. (#2132)
* Add the argsort cuda kernels. * CPU version of arg-sort. * Hook the cuda kernel + rework the cpu bits. * Add some dedicated test. * Working cuda kernel. * Metal kernel. * Metal adjustments. * Bugfix. * Use the fast rope in qwen. * Rework the expert selection in qwen.
Diffstat (limited to 'candle-metal-kernels')
-rw-r--r--candle-metal-kernels/src/lib.rs40
-rw-r--r--candle-metal-kernels/src/quantized.metal1
-rw-r--r--candle-metal-kernels/src/sort.metal97
3 files changed, 138 insertions, 0 deletions
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index 10f942b4..8e075d5a 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -21,6 +21,7 @@ const REDUCE: &str = include_str!("reduce.metal");
const RANDOM: &str = include_str!("random.metal");
const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");
const QUANTIZED: &str = include_str!("quantized.metal");
+const SORT: &str = include_str!("sort.metal");
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Source {
@@ -35,6 +36,7 @@ pub enum Source {
Conv,
Random,
Quantized,
+ Sort,
}
pub mod copy2d {
@@ -197,6 +199,7 @@ impl Kernels {
Source::Conv => CONV,
Source::Random => RANDOM,
Source::Quantized => QUANTIZED,
+ Source::Sort => SORT,
Source::Mfa => panic!("Invalid lib"),
}
}
@@ -2048,5 +2051,42 @@ pub fn call_conv_transpose2d(
Ok(())
}
+#[allow(clippy::too_many_arguments)]
+pub fn call_arg_sort(
+ device: &Device,
+ command_buffer: &CommandBufferRef,
+ kernels: &Kernels,
+ name: &'static str,
+ nrows: usize,
+ ncols: usize,
+ ncols_pad: usize,
+ src: BufferOffset,
+ dst: &Buffer,
+) -> Result<(), MetalKernelError> {
+ let pipeline = kernels.load_pipeline(device, Source::Sort, name)?;
+ let encoder = command_buffer.new_compute_command_encoder();
+ encoder.set_compute_pipeline_state(&pipeline);
+
+ set_params!(encoder, (&src, dst, ncols as i64, ncols_pad as i64));
+
+ let thread_group_count = MTLSize {
+ width: 1,
+ height: nrows as u64,
+ depth: 1,
+ };
+ let thread_group_size = MTLSize {
+ width: ncols_pad as u64,
+ height: 1,
+ depth: 1,
+ };
+
+ encoder.use_resource(src.buffer, metal::MTLResourceUsage::Read);
+ encoder.use_resource(dst, metal::MTLResourceUsage::Write);
+ encoder.set_threadgroup_memory_length(0, (ncols_pad * 4).max(16) as u64);
+ encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+ encoder.end_encoding();
+ Ok(())
+}
+
#[cfg(test)]
mod tests;
diff --git a/candle-metal-kernels/src/quantized.metal b/candle-metal-kernels/src/quantized.metal
index 9aa7b502..fef6ac54 100644
--- a/candle-metal-kernels/src/quantized.metal
+++ b/candle-metal-kernels/src/quantized.metal
@@ -1,3 +1,4 @@
+// Imported from https://github.com/ggerganov/llama.cpp/blob/master/ggml-metal.metal
#include <metal_stdlib>
using namespace metal;
diff --git a/candle-metal-kernels/src/sort.metal b/candle-metal-kernels/src/sort.metal
new file mode 100644
index 00000000..d71ab822
--- /dev/null
+++ b/candle-metal-kernels/src/sort.metal
@@ -0,0 +1,97 @@
+// Imported from https://github.com/ggerganov/llama.cpp/blob/master/ggml-metal.metal
+#include <metal_stdlib>
+using namespace metal;
+
+#define SWAP(x, y) { auto tmp = (x); (x) = (y); (y) = tmp; }
+#define SORT_ASC 1
+#define SORT_DESC 0
+
+template<int order, typename T>
+METAL_FUNC void argsort(
+ device const T * x,
+ device uint32_t * dst,
+ constant int64_t & ncols,
+ constant int64_t & ncols_pad,
+ threadgroup uint32_t * shared_values [[threadgroup(0)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]]) {
+ int col = tpitg[0];
+ int row = tgpig[1];
+
+ if (col >= ncols_pad) return;
+
+ device const T * x_row = x + row * ncols;
+ threadgroup uint32_t * dst_row = shared_values;
+
+ // initialize indices
+ dst_row[col] = col;
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ for (int k = 2; k <= ncols_pad; k *= 2) {
+ for (int j = k / 2; j > 0; j /= 2) {
+ int ixj = col ^ j;
+ if (ixj > col) {
+ if ((col & k) == 0) {
+ if (dst_row[col] >= ncols ||
+ (dst_row[ixj] < ncols && (order == SORT_ASC ?
+ x_row[dst_row[col]] > x_row[dst_row[ixj]] :
+ x_row[dst_row[col]] < x_row[dst_row[ixj]]))
+ ) {
+ SWAP(dst_row[col], dst_row[ixj]);
+ }
+ } else {
+ if (dst_row[ixj] >= ncols ||
+ (dst_row[col] < ncols && (order == SORT_ASC ?
+ x_row[dst_row[col]] < x_row[dst_row[ixj]] :
+ x_row[dst_row[col]] > x_row[dst_row[ixj]]))
+ ) {
+ SWAP(dst_row[col], dst_row[ixj]);
+ }
+ }
+ }
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ }
+ }
+
+ // copy the result to dst without the padding
+ if (col < ncols) {
+ dst[row * ncols + col] = dst_row[col];
+ }
+}
+
+#define ARGSORT(T, RUST_T) \
+kernel void asort_asc_##RUST_T( \
+ device const T * x, \
+ device uint32_t * dst, \
+ constant int64_t & ncols, \
+ constant int64_t & ncols_pad, \
+ threadgroup uint32_t * shared_values [[threadgroup(0)]], \
+ uint3 tgpig[[threadgroup_position_in_grid]], \
+ uint3 tpitg[[thread_position_in_threadgroup]] \
+) { \
+ argsort<SORT_ASC, T>(x, dst, ncols, ncols_pad, shared_values, tgpig, tpitg); \
+} \
+kernel void asort_desc_##RUST_T( \
+ device const T * x, \
+ device uint32_t * dst, \
+ constant int64_t & ncols, \
+ constant int64_t & ncols_pad, \
+ threadgroup uint32_t * shared_values [[threadgroup(0)]], \
+ uint3 tgpig[[threadgroup_position_in_grid]], \
+ uint3 tpitg[[thread_position_in_threadgroup]] \
+) { \
+ argsort<SORT_DESC, T>(x, dst, ncols, ncols_pad, shared_values, tgpig, tpitg); \
+} \
+
+ARGSORT(float, f32)
+ARGSORT(half, f16)
+ARGSORT(uint8_t, u8)
+ARGSORT(uint32_t, u32)
+
+#if __METAL_VERSION__ >= 220
+ARGSORT(int64_t, i64)
+#endif
+#if defined(__HAVE_BFLOAT__)
+ARGSORT(bfloat, bf16)
+#endif