diff options
Diffstat (limited to 'candle-metal-kernels/src/lib.rs')
-rw-r--r-- | candle-metal-kernels/src/lib.rs | 40 |
1 files changed, 40 insertions, 0 deletions
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index 10f942b4..8e075d5a 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -21,6 +21,7 @@ const REDUCE: &str = include_str!("reduce.metal"); const RANDOM: &str = include_str!("random.metal"); const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib"); const QUANTIZED: &str = include_str!("quantized.metal"); +const SORT: &str = include_str!("sort.metal"); #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum Source { @@ -35,6 +36,7 @@ pub enum Source { Conv, Random, Quantized, + Sort, } pub mod copy2d { @@ -197,6 +199,7 @@ impl Kernels { Source::Conv => CONV, Source::Random => RANDOM, Source::Quantized => QUANTIZED, + Source::Sort => SORT, Source::Mfa => panic!("Invalid lib"), } } @@ -2048,5 +2051,42 @@ pub fn call_conv_transpose2d( Ok(()) } +#[allow(clippy::too_many_arguments)] +pub fn call_arg_sort( + device: &Device, + command_buffer: &CommandBufferRef, + kernels: &Kernels, + name: &'static str, + nrows: usize, + ncols: usize, + ncols_pad: usize, + src: BufferOffset, + dst: &Buffer, +) -> Result<(), MetalKernelError> { + let pipeline = kernels.load_pipeline(device, Source::Sort, name)?; + let encoder = command_buffer.new_compute_command_encoder(); + encoder.set_compute_pipeline_state(&pipeline); + + set_params!(encoder, (&src, dst, ncols as i64, ncols_pad as i64)); + + let thread_group_count = MTLSize { + width: 1, + height: nrows as u64, + depth: 1, + }; + let thread_group_size = MTLSize { + width: ncols_pad as u64, + height: 1, + depth: 1, + }; + + encoder.use_resource(src.buffer, metal::MTLResourceUsage::Read); + encoder.use_resource(dst, metal::MTLResourceUsage::Write); + encoder.set_threadgroup_memory_length(0, (ncols_pad * 4).max(16) as u64); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.end_encoding(); + Ok(()) +} + #[cfg(test)] mod tests; |