summaryrefslogtreecommitdiff
path: root/candle-metal-kernels/src/lib.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-metal-kernels/src/lib.rs')
-rw-r--r--candle-metal-kernels/src/lib.rs40
1 files changed, 40 insertions, 0 deletions
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index 10f942b4..8e075d5a 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -21,6 +21,7 @@ const REDUCE: &str = include_str!("reduce.metal");
const RANDOM: &str = include_str!("random.metal");
const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");
const QUANTIZED: &str = include_str!("quantized.metal");
+const SORT: &str = include_str!("sort.metal");
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Source {
@@ -35,6 +36,7 @@ pub enum Source {
Conv,
Random,
Quantized,
+ Sort,
}
pub mod copy2d {
@@ -197,6 +199,7 @@ impl Kernels {
Source::Conv => CONV,
Source::Random => RANDOM,
Source::Quantized => QUANTIZED,
+ Source::Sort => SORT,
Source::Mfa => panic!("Invalid lib"),
}
}
@@ -2048,5 +2051,42 @@ pub fn call_conv_transpose2d(
Ok(())
}
+#[allow(clippy::too_many_arguments)]
+pub fn call_arg_sort(
+ device: &Device,
+ command_buffer: &CommandBufferRef,
+ kernels: &Kernels,
+ name: &'static str,
+ nrows: usize,
+ ncols: usize,
+ ncols_pad: usize,
+ src: BufferOffset,
+ dst: &Buffer,
+) -> Result<(), MetalKernelError> {
+ let pipeline = kernels.load_pipeline(device, Source::Sort, name)?;
+ let encoder = command_buffer.new_compute_command_encoder();
+ encoder.set_compute_pipeline_state(&pipeline);
+
+ set_params!(encoder, (&src, dst, ncols as i64, ncols_pad as i64));
+
+ let thread_group_count = MTLSize {
+ width: 1,
+ height: nrows as u64,
+ depth: 1,
+ };
+ let thread_group_size = MTLSize {
+ width: ncols_pad as u64,
+ height: 1,
+ depth: 1,
+ };
+
+ encoder.use_resource(src.buffer, metal::MTLResourceUsage::Read);
+ encoder.use_resource(dst, metal::MTLResourceUsage::Write);
+ encoder.set_threadgroup_memory_length(0, (ncols_pad * 4).max(16) as u64);
+ encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+ encoder.end_encoding();
+ Ok(())
+}
+
#[cfg(test)]
mod tests;