diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-04-05 08:32:58 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-04-05 08:32:58 +0200 |
commit | 2ac302a5d170953a1d2fe850645563fc55d1567f (patch) | |
tree | b35b32efe5c8eac25a9b5681fb0778ef84e57d0e /candle-metal-kernels/src | |
parent | ace282e5c2ef24ca2fb90683babb852936d4df17 (diff) | |
download | candle-2ac302a5d170953a1d2fe850645563fc55d1567f.tar.gz candle-2ac302a5d170953a1d2fe850645563fc55d1567f.tar.bz2 candle-2ac302a5d170953a1d2fe850645563fc55d1567f.zip |
Add the rope THD kernel. (#2014)
* Add the rope THD kernel.
* Cuda kernel for rope-thd.
* Add the metal kernels.
* Add a dedicated test.
Diffstat (limited to 'candle-metal-kernels/src')
-rw-r--r-- | candle-metal-kernels/src/lib.rs | 45 | ||||
-rw-r--r-- | candle-metal-kernels/src/reduce.metal | 48 |
2 files changed, 89 insertions, 4 deletions
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index 5af48fae..4cff9bda 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -850,6 +850,51 @@ pub fn call_rope_i( } #[allow(clippy::too_many_arguments)] +pub fn call_rope_thd( + device: &Device, + command_buffer: &CommandBufferRef, + kernels: &Kernels, + kernel_name: &'static str, + b: usize, + t: usize, + h: usize, + d: usize, + src: &Buffer, + src_offset: usize, + cos: &Buffer, + cos_offset: usize, + sin: &Buffer, + sin_offset: usize, + output: &Buffer, +) -> Result<(), MetalKernelError> { + let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?; + let encoder = command_buffer.new_compute_command_encoder(); + encoder.set_compute_pipeline_state(&pipeline); + + set_params!( + encoder, + ( + b, + t, + h, + d, + (src, src_offset), + (cos, cos_offset), + (sin, sin_offset), + output + ) + ); + let (thread_group_count, thread_group_size) = linear_split(&pipeline, (b * t * h * d) / 2); + encoder.use_resource(src, metal::MTLResourceUsage::Read); + encoder.use_resource(cos, metal::MTLResourceUsage::Read); + encoder.use_resource(sin, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.end_encoding(); + Ok(()) +} + +#[allow(clippy::too_many_arguments)] pub fn call_rope( device: &Device, command_buffer: &CommandBufferRef, diff --git a/candle-metal-kernels/src/reduce.metal b/candle-metal-kernels/src/reduce.metal index acb69299..14bfb297 100644 --- a/candle-metal-kernels/src/reduce.metal +++ b/candle-metal-kernels/src/reduce.metal @@ -418,7 +418,34 @@ METAL_FUNC void rope( dst[i2] = src[i1] * s + src[i2] * c; } -#define ROPEI(FN_NAME, FN_NAME_I, TYPENAME) \ +template<typename T> +METAL_FUNC void rope_thd( + constant size_t &b, + constant size_t &t, + constant size_t &h, + constant size_t &d, + device const T *src, + device const T *cos, + device const T *sin, + device T *dst, + uint idx +) { + if (2 * idx >= b * t * h * d) { + return; + } + const size_t i_bth = idx / (d / 2); + const size_t i_d = idx - (d / 2) * i_bth; + const size_t i_t = (i_bth / h) % t; + const size_t i1 = i_bth * d + i_d; + const size_t i2 = i1 + d / 2; + const size_t i_cs = i_t * (d / 2) + i_d; + T c = cos[i_cs]; + T s = sin[i_cs]; + dst[i1] = src[i1] * c - src[i2] * s; + dst[i2] = src[i1] * s + src[i2] * c; +} + +#define ROPE(FN_NAME, FN_NAME_I, FN_NAME_THD, TYPENAME) \ kernel void FN_NAME_I( \ constant size_t &bh, \ constant size_t &td, \ @@ -442,6 +469,19 @@ kernel void FN_NAME( \ ) { \ rope<TYPENAME>(bh, td, d, src, cos, sin, dst, idx); \ }\ +kernel void FN_NAME_THD( \ + constant size_t &b, \ + constant size_t &t, \ + constant size_t &h, \ + constant size_t &d, \ + device const TYPENAME *src, \ + device const TYPENAME *cos, \ + device const TYPENAME *sin, \ + device TYPENAME *dst, \ + uint idx [[ thread_position_in_grid ]] \ +) { \ + rope_thd<TYPENAME>(b, t, h, d, src, cos, sin, dst, idx); \ +}\ REDUCE(x + y, fast_sum_f32_strided, float, 0) REDUCE(x + y, fast_sum_u32_strided, uint, 0) @@ -471,8 +511,8 @@ SOFTMAX(softmax_f32, float) SOFTMAX(softmax_f16, half) RMSNORM(rmsnorm_f32, float) RMSNORM(rmsnorm_f16, half) -ROPEI(rope_f32, rope_i_f32, float) -ROPEI(rope_f16, rope_i_f16, half) +ROPE(rope_f32, rope_i_f32, rope_thd_f32, float) +ROPE(rope_f16, rope_i_f16, rope_thd_f16, half) #if __METAL_VERSION__ >= 220 REDUCE(x + y, fast_sum_i64_strided, int64_t, 0) @@ -495,5 +535,5 @@ ARGMIN(fast_argmin_bf16, bfloat, HUGE_VALBF) ARGMAX(fast_argmax_bf16, bfloat, -HUGE_VALBF) SOFTMAX(softmax_bf16, bfloat) RMSNORM(rmsnorm_bf16, bfloat) -ROPEI(rope_bf16, rope_i_bf16, bfloat) +ROPE(rope_bf16, rope_i_bf16, rope_thd_bf16, bfloat) #endif |