diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-03-24 22:48:52 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-03-24 22:48:52 +0100 |
commit | 1b98f84a2baa23192b97e36131011da658bfa1c2 (patch) | |
tree | 92c4e9e8a263edfc8d3fedeab2cc02271d87d51e /candle-metal-kernels/src/lib.rs | |
parent | cf7d7fcf2f20c24aae633483c3a107c1219a7f9a (diff) | |
download | candle-1b98f84a2baa23192b97e36131011da658bfa1c2.tar.gz candle-1b98f84a2baa23192b97e36131011da658bfa1c2.tar.bz2 candle-1b98f84a2baa23192b97e36131011da658bfa1c2.zip |
Fast kernels for rotary embeddings. (#1928)
* Fast kernels for rotary embeddings.
* Add a test for the fast CPU kernel.
* Rope cuda bindings.
* Cuda kernel.
* Metal kernel (part 1).
* Cuda kernels.
* Finish the metal kernel.
* Use the new kernels in the quantized example.
* Fix warning.
Diffstat (limited to 'candle-metal-kernels/src/lib.rs')
-rw-r--r-- | candle-metal-kernels/src/lib.rs | 41 |
1 files changed, 41 insertions, 0 deletions
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index e17365a0..e83814a8 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -809,6 +809,47 @@ pub fn call_rms_norm( } #[allow(clippy::too_many_arguments)] +pub fn call_rope_i( + device: &Device, + command_buffer: &CommandBufferRef, + kernels: &Kernels, + kernel_name: &'static str, + bh: usize, + td: usize, + src: &Buffer, + src_offset: usize, + cos: &Buffer, + cos_offset: usize, + sin: &Buffer, + sin_offset: usize, + output: &Buffer, +) -> Result<(), MetalKernelError> { + let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?; + let encoder = command_buffer.new_compute_command_encoder(); + encoder.set_compute_pipeline_state(&pipeline); + + set_params!( + encoder, + ( + bh, + td, + (src, src_offset), + (cos, cos_offset), + (sin, sin_offset), + output + ) + ); + let (thread_group_count, thread_group_size) = linear_split(&pipeline, (bh * td) / 2); + encoder.use_resource(src, metal::MTLResourceUsage::Read); + encoder.use_resource(cos, metal::MTLResourceUsage::Read); + encoder.use_resource(sin, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.end_encoding(); + Ok(()) +} + +#[allow(clippy::too_many_arguments)] pub fn call_affine( device: &Device, command_buffer: &CommandBufferRef, |