diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-04-27 20:17:35 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-04-27 20:17:35 +0200 |
commit | 96a48e5cc42b3c94d9d9687bb29987953df36db8 (patch) | |
tree | 4f1f391e6e6a8c1b865c4ab40e67aaf84dd21499 /candle-metal-kernels | |
parent | 6cf82fd7a34641601264ad1e0256ecadb7222474 (diff) | |
download | candle-96a48e5cc42b3c94d9d9687bb29987953df36db8.tar.gz candle-96a48e5cc42b3c94d9d9687bb29987953df36db8.tar.bz2 candle-96a48e5cc42b3c94d9d9687bb29987953df36db8.zip |
Add argsort. (#2132)
* Add the argsort cuda kernels.
* CPU version of arg-sort.
* Hook the cuda kernel + rework the cpu bits.
* Add some dedicated tests.
* Working cuda kernel.
* Metal kernel.
* Metal adjustments.
* Bugfix.
* Use the fast rope in qwen.
* Rework the expert selection in qwen.
Diffstat (limited to 'candle-metal-kernels')
-rw-r--r-- | candle-metal-kernels/src/lib.rs | 40 | ||||
-rw-r--r-- | candle-metal-kernels/src/quantized.metal | 1 | ||||
-rw-r--r-- | candle-metal-kernels/src/sort.metal | 97 |
3 files changed, 138 insertions, 0 deletions
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index 10f942b4..8e075d5a 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -21,6 +21,7 @@ const REDUCE: &str = include_str!("reduce.metal"); const RANDOM: &str = include_str!("random.metal"); const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib"); const QUANTIZED: &str = include_str!("quantized.metal"); +const SORT: &str = include_str!("sort.metal"); #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum Source { @@ -35,6 +36,7 @@ pub enum Source { Conv, Random, Quantized, + Sort, } pub mod copy2d { @@ -197,6 +199,7 @@ impl Kernels { Source::Conv => CONV, Source::Random => RANDOM, Source::Quantized => QUANTIZED, + Source::Sort => SORT, Source::Mfa => panic!("Invalid lib"), } } @@ -2048,5 +2051,42 @@ pub fn call_conv_transpose2d( Ok(()) } +#[allow(clippy::too_many_arguments)] +pub fn call_arg_sort( + device: &Device, + command_buffer: &CommandBufferRef, + kernels: &Kernels, + name: &'static str, + nrows: usize, + ncols: usize, + ncols_pad: usize, + src: BufferOffset, + dst: &Buffer, +) -> Result<(), MetalKernelError> { + let pipeline = kernels.load_pipeline(device, Source::Sort, name)?; + let encoder = command_buffer.new_compute_command_encoder(); + encoder.set_compute_pipeline_state(&pipeline); + + set_params!(encoder, (&src, dst, ncols as i64, ncols_pad as i64)); + + let thread_group_count = MTLSize { + width: 1, + height: nrows as u64, + depth: 1, + }; + let thread_group_size = MTLSize { + width: ncols_pad as u64, + height: 1, + depth: 1, + }; + + encoder.use_resource(src.buffer, metal::MTLResourceUsage::Read); + encoder.use_resource(dst, metal::MTLResourceUsage::Write); + encoder.set_threadgroup_memory_length(0, (ncols_pad * 4).max(16) as u64); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.end_encoding(); + Ok(()) +} + #[cfg(test)] mod tests; diff --git 
a/candle-metal-kernels/src/quantized.metal b/candle-metal-kernels/src/quantized.metal index 9aa7b502..fef6ac54 100644 --- a/candle-metal-kernels/src/quantized.metal +++ b/candle-metal-kernels/src/quantized.metal @@ -1,3 +1,4 @@ +// Imported from https://github.com/ggerganov/llama.cpp/blob/master/ggml-metal.metal #include <metal_stdlib> using namespace metal; diff --git a/candle-metal-kernels/src/sort.metal b/candle-metal-kernels/src/sort.metal new file mode 100644 index 00000000..d71ab822 --- /dev/null +++ b/candle-metal-kernels/src/sort.metal @@ -0,0 +1,97 @@ +// Imported from https://github.com/ggerganov/llama.cpp/blob/master/ggml-metal.metal +#include <metal_stdlib> +using namespace metal; + +#define SWAP(x, y) { auto tmp = (x); (x) = (y); (y) = tmp; } +#define SORT_ASC 1 +#define SORT_DESC 0 + +template<int order, typename T> +METAL_FUNC void argsort( + device const T * x, + device uint32_t * dst, + constant int64_t & ncols, + constant int64_t & ncols_pad, + threadgroup uint32_t * shared_values [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]]) { + int col = tpitg[0]; + int row = tgpig[1]; + + if (col >= ncols_pad) return; + + device const T * x_row = x + row * ncols; + threadgroup uint32_t * dst_row = shared_values; + + // initialize indices + dst_row[col] = col; + + threadgroup_barrier(mem_flags::mem_threadgroup); + + for (int k = 2; k <= ncols_pad; k *= 2) { + for (int j = k / 2; j > 0; j /= 2) { + int ixj = col ^ j; + if (ixj > col) { + if ((col & k) == 0) { + if (dst_row[col] >= ncols || + (dst_row[ixj] < ncols && (order == SORT_ASC ? + x_row[dst_row[col]] > x_row[dst_row[ixj]] : + x_row[dst_row[col]] < x_row[dst_row[ixj]])) + ) { + SWAP(dst_row[col], dst_row[ixj]); + } + } else { + if (dst_row[ixj] >= ncols || + (dst_row[col] < ncols && (order == SORT_ASC ? 
+ x_row[dst_row[col]] < x_row[dst_row[ixj]] : + x_row[dst_row[col]] > x_row[dst_row[ixj]])) + ) { + SWAP(dst_row[col], dst_row[ixj]); + } + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + } + } + + // copy the result to dst without the padding + if (col < ncols) { + dst[row * ncols + col] = dst_row[col]; + } +} + +#define ARGSORT(T, RUST_T) \ +kernel void asort_asc_##RUST_T( \ + device const T * x, \ + device uint32_t * dst, \ + constant int64_t & ncols, \ + constant int64_t & ncols_pad, \ + threadgroup uint32_t * shared_values [[threadgroup(0)]], \ + uint3 tgpig[[threadgroup_position_in_grid]], \ + uint3 tpitg[[thread_position_in_threadgroup]] \ +) { \ + argsort<SORT_ASC, T>(x, dst, ncols, ncols_pad, shared_values, tgpig, tpitg); \ +} \ +kernel void asort_desc_##RUST_T( \ + device const T * x, \ + device uint32_t * dst, \ + constant int64_t & ncols, \ + constant int64_t & ncols_pad, \ + threadgroup uint32_t * shared_values [[threadgroup(0)]], \ + uint3 tgpig[[threadgroup_position_in_grid]], \ + uint3 tpitg[[thread_position_in_threadgroup]] \ +) { \ + argsort<SORT_DESC, T>(x, dst, ncols, ncols_pad, shared_values, tgpig, tpitg); \ +} \ + +ARGSORT(float, f32) +ARGSORT(half, f16) +ARGSORT(uint8_t, u8) +ARGSORT(uint32_t, u32) + +#if __METAL_VERSION__ >= 220 +ARGSORT(int64_t, i64) +#endif +#if defined(__HAVE_BFLOAT__) +ARGSORT(bfloat, bf16) +#endif |