diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-03-24 22:48:52 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-03-24 22:48:52 +0100 |
commit | 1b98f84a2baa23192b97e36131011da658bfa1c2 (patch) | |
tree | 92c4e9e8a263edfc8d3fedeab2cc02271d87d51e /candle-metal-kernels | |
parent | cf7d7fcf2f20c24aae633483c3a107c1219a7f9a (diff) | |
download | candle-1b98f84a2baa23192b97e36131011da658bfa1c2.tar.gz candle-1b98f84a2baa23192b97e36131011da658bfa1c2.tar.bz2 candle-1b98f84a2baa23192b97e36131011da658bfa1c2.zip |
Fast kernels for rotary embeddings. (#1928)
* Fast kernels for rotary embeddings.
* Add a test for the fast CPU kernel.
* Rope cuda bindings.
* Cuda kernel.
* Metal kernel (part 1).
* Cuda kernels.
* Finish the metal kernel.
* Use the new kernels in the quantized example.
* Fix warning.
Diffstat (limited to 'candle-metal-kernels')
-rw-r--r-- | candle-metal-kernels/src/lib.rs | 41 | ||||
-rw-r--r-- | candle-metal-kernels/src/reduce.metal | 23 |
2 files changed, 64 insertions, 0 deletions
```diff
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index e17365a0..e83814a8 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -809,6 +809,47 @@ pub fn call_rms_norm(
 }
 
 #[allow(clippy::too_many_arguments)]
+pub fn call_rope_i(
+    device: &Device,
+    command_buffer: &CommandBufferRef,
+    kernels: &Kernels,
+    kernel_name: &'static str,
+    bh: usize,
+    td: usize,
+    src: &Buffer,
+    src_offset: usize,
+    cos: &Buffer,
+    cos_offset: usize,
+    sin: &Buffer,
+    sin_offset: usize,
+    output: &Buffer,
+) -> Result<(), MetalKernelError> {
+    let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
+    let encoder = command_buffer.new_compute_command_encoder();
+    encoder.set_compute_pipeline_state(&pipeline);
+
+    set_params!(
+        encoder,
+        (
+            bh,
+            td,
+            (src, src_offset),
+            (cos, cos_offset),
+            (sin, sin_offset),
+            output
+        )
+    );
+    let (thread_group_count, thread_group_size) = linear_split(&pipeline, (bh * td) / 2);
+    encoder.use_resource(src, metal::MTLResourceUsage::Read);
+    encoder.use_resource(cos, metal::MTLResourceUsage::Read);
+    encoder.use_resource(sin, metal::MTLResourceUsage::Read);
+    encoder.use_resource(output, metal::MTLResourceUsage::Write);
+    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+    encoder.end_encoding();
+    Ok(())
+}
+
+#[allow(clippy::too_many_arguments)]
 pub fn call_affine(
     device: &Device,
     command_buffer: &CommandBufferRef,
diff --git a/candle-metal-kernels/src/reduce.metal b/candle-metal-kernels/src/reduce.metal
index 3c3cbc14..fa980dea 100644
--- a/candle-metal-kernels/src/reduce.metal
+++ b/candle-metal-kernels/src/reduce.metal
@@ -313,6 +313,26 @@ kernel void NAME(
     } \
 } \
 
+#define ROPEI(FN_NAME, TYPENAME) \
+kernel void FN_NAME( \
+    constant size_t &bh, \
+    constant size_t &td, \
+    device const TYPENAME *src, \
+    device const TYPENAME *cos, \
+    device const TYPENAME *sin, \
+    device TYPENAME *dst, \
+    uint tid [[ thread_position_in_grid ]] \
+) { \
+    if (2 * tid >= bh * td) { \
+        return; \
+    } \
+    size_t rope_idx = tid % (td / 2); \
+    TYPENAME c = cos[rope_idx]; \
+    TYPENAME s = sin[rope_idx]; \
+    dst[2 * tid] = src[2 * tid] * c - src[2 * tid + 1] * s; \
+    dst[2 * tid + 1] = src[2 * tid] * s + src[2 * tid + 1] * c; \
+}\
+
 REDUCE(x + y, fast_sum_f32_strided, float, 0)
 REDUCE(x + y, fast_sum_u32_strided, uint, 0)
 REDUCE(x + y, fast_sum_f16_strided, half, 0)
@@ -341,6 +361,8 @@ SOFTMAX(softmax_f32, float)
 SOFTMAX(softmax_f16, half)
 RMSNORM(rmsnorm_f32, float)
 RMSNORM(rmsnorm_f16, half)
+ROPEI(rope_i_f32, float)
+ROPEI(rope_i_f16, half)
 
 #if __METAL_VERSION__ >= 220
 REDUCE(x + y, fast_sum_i64_strided, int64_t, 0)
@@ -359,4 +381,5 @@ ARGMIN(fast_argmin_bf16, bfloat, HUGE_VALBF)
 ARGMAX(fast_argmax_bf16, bfloat, -HUGE_VALBF)
 SOFTMAX(softmax_bf16, bfloat)
 RMSNORM(rmsnorm_bf16, bfloat)
+ROPEI(rope_i_bf16, bfloat)
 #endif
```