author     Laurent Mazare <laurent.mazare@gmail.com>  2024-05-24 15:58:01 +0200
committer  GitHub <noreply@github.com>  2024-05-24 15:58:01 +0200
commit     1df2bddccfbb4ab511a8cc3a87476d1fa72416bc
tree       3633bc51e3bac3d542d9dfe06d509db20f5374e9 /candle-metal-kernels
parent     6f0b807ffd553fed27325a2a118b0e30bb6d9cbd
Add the layernorm specialized op. (#2212)
* Add the layernorm cuda kernels.
* Dedicated layer norm op.
* Add the slower variant.
* Plug the cuda implementation.
* Add the metal variant.
* Add a dedicated test.
* Bugfix.
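The op computes a standard layer normalization: each row of `elements_to_sum` values is shifted to zero mean and unit variance, then scaled by `alpha` and offset by `beta`. A minimal scalar sketch of that computation in Rust (a hypothetical `layer_norm_ref` helper for illustration; this is not the dedicated test added by the commit):

/// Reference layer norm: y = (x - mean) / sqrt(var + eps) * alpha + beta,
/// applied independently to each row of `n` values.
fn layer_norm_ref(src: &[f32], alpha: &[f32], beta: &[f32], n: usize, eps: f32) -> Vec<f32> {
    src.chunks(n)
        .flat_map(|row| {
            let mean = row.iter().sum::<f32>() / n as f32;
            // Variance as E[x^2] - E[x]^2, matching the kernel's two running sums.
            let sq_mean = row.iter().map(|x| x * x).sum::<f32>() / n as f32;
            let inv_norm = 1.0 / (sq_mean - mean * mean + eps).sqrt();
            row.iter()
                .enumerate()
                .map(|(i, x)| (x - mean) * inv_norm * alpha[i] + beta[i])
                .collect::<Vec<_>>()
        })
        .collect()
}

The Metal kernel in the diff below computes exactly this, with the per-row reduction parallelized across a threadgroup.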
Diffstat (limited to 'candle-metal-kernels')
 candle-metal-kernels/src/lib.rs       | 63
 candle-metal-kernels/src/reduce.metal | 81
 2 files changed, 144 insertions(+), 0 deletions(-)
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index 814ca0b9..aa157a2f 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -739,6 +739,69 @@ pub fn call_rms_norm(
     encoder.use_resource(input, metal::MTLResourceUsage::Read);
     encoder.use_resource(output, metal::MTLResourceUsage::Write);
+    encoder.set_threadgroup_memory_length(0, (width * 4).max(16) as u64);
+    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+    encoder.end_encoding();
+    Ok(())
+}
+
+#[allow(clippy::too_many_arguments)]
+pub fn call_layer_norm(
+    device: &Device,
+    command_buffer: &CommandBufferRef,
+    kernels: &Kernels,
+    kernel_name: &'static str,
+    length: usize,
+    elements_to_sum: usize,
+    eps: f32,
+    input: &Buffer,
+    input_offset: usize,
+    alpha: &Buffer,
+    alpha_offset: usize,
+    beta: &Buffer,
+    beta_offset: usize,
+    output: &Buffer,
+) -> Result<(), MetalKernelError> {
+    let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
+    let encoder = command_buffer.new_compute_command_encoder();
+    encoder.set_compute_pipeline_state(&pipeline);
+
+    set_params!(
+        encoder,
+        (
+            length,
+            elements_to_sum,
+            (input, input_offset),
+            output,
+            (alpha, alpha_offset),
+            (beta, beta_offset),
+            eps
+        )
+    );
+
+    let out_length = length / elements_to_sum;
+
+    let thread_group_count = MTLSize {
+        width: out_length as u64,
+        height: 1,
+        depth: 1,
+    };
+
+    let width = std::cmp::min(
+        pipeline.max_total_threads_per_threadgroup(),
+        elements_to_sum as u64,
+    )
+    .next_power_of_two();
+
+    let thread_group_size = MTLSize {
+        width,
+        height: 1,
+        depth: 1,
+    };
+
+    encoder.use_resource(input, metal::MTLResourceUsage::Read);
+    encoder.use_resource(output, metal::MTLResourceUsage::Write);
+    encoder.set_threadgroup_memory_length(0, (width * 8).max(32) as u64);
     encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
     encoder.end_encoding();
     Ok(())
 }
diff --git a/candle-metal-kernels/src/reduce.metal b/candle-metal-kernels/src/reduce.metal
index 14bfb297..e009ca1d 100644
--- a/candle-metal-kernels/src/reduce.metal
+++ b/candle-metal-kernels/src/reduce.metal
@@ -353,6 +353,65 @@ METAL_FUNC void rmsnorm(
     }
 }
 
+template<typename T>
+METAL_FUNC void layernorm(
+    constant size_t & src_numel,
+    constant size_t & el_to_sum_per_block,
+    device const T * src,
+    device T * dst,
+    device const T * alpha,
+    device const T * beta,
+    constant float & eps,
+    uint id,
+    uint tid,
+    uint dst_id,
+    uint block_dim,
+    threadgroup float * shared_memory
+) {
+    size_t start_idx = dst_id * el_to_sum_per_block;
+    size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel);
+    size_t idx = start_idx + tid;
+
+    float tmp1 = 0;
+    float tmp2 = 0;
+    while (idx < stop_idx) {
+        tmp1 += float(src[idx]);
+        tmp2 += float(src[idx]) * float(src[idx]);
+        idx += block_dim;
+    }
+    shared_memory[tid] = tmp1;
+    shared_memory[tid + block_dim] = tmp2;
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    for (uint s = block_dim / 2; s > 0; s >>= 1) {
+        if (tid < s) {
+            shared_memory[tid] = shared_memory[tid] + shared_memory[tid + s];
+            shared_memory[block_dim + tid] = shared_memory[block_dim + tid] + shared_memory[block_dim + tid + s];
+        }
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+    }
+
+    /* wait for shared_memory[0] to be filled */
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    float mean = shared_memory[0] / float(el_to_sum_per_block);
+    float var = shared_memory[block_dim] / float(el_to_sum_per_block) - mean * mean;
+    float inv_norm = 1.0f / sqrt(var + eps);
+    idx = start_idx + tid;
+    while (idx < stop_idx) {
+        float val = (float(src[idx]) - mean) * inv_norm;
+        if (alpha != nullptr) {
+            val *= float(alpha[idx - start_idx]);
+        }
+        if (beta != nullptr) {
+            val += float(beta[idx - start_idx]);
+        }
+        dst[idx] = T(val);
+        idx += block_dim;
+    }
+}
+
 #define RMSNORM(NAME, T) \
 kernel void NAME( \
     constant size_t &src_numel, \
@@ -371,6 +430,25 @@ kernel void NAME( \
     rmsnorm<T>(src_numel, el_to_sum_per_block, src, dst, alpha, eps, id, tid, dst_id, block_dim, shared_memory); \
 } \
 
+#define LAYERNORM(NAME, T) \
+kernel void NAME( \
+    constant size_t &src_numel, \
+    constant size_t &el_to_sum_per_block, \
+    device const T *src, \
+    device T *dst, \
+    device const T *alpha, \
+    device const T *beta, \
+    constant float &eps, \
+    uint id [[ thread_position_in_grid ]], \
+    uint tid [[ thread_index_in_threadgroup ]], \
+    uint dst_id [[ threadgroup_position_in_grid ]], \
+    uint block_dim [[ threads_per_threadgroup ]] \
+) { \
+    threadgroup float shared_memory[THREADGROUP_SIZE]; \
+    shared_memory[tid] = 0; \
+    layernorm<T>(src_numel, el_to_sum_per_block, src, dst, alpha, beta, eps, id, tid, dst_id, block_dim, shared_memory); \
+} \
+
 template<typename T>
 METAL_FUNC void ropei(
     constant size_t &bh,
@@ -511,6 +589,8 @@ SOFTMAX(softmax_f32, float)
 SOFTMAX(softmax_f16, half)
 RMSNORM(rmsnorm_f32, float)
 RMSNORM(rmsnorm_f16, half)
+LAYERNORM(layernorm_f32, float)
+LAYERNORM(layernorm_f16, half)
 ROPE(rope_f32, rope_i_f32, rope_thd_f32, float)
 ROPE(rope_f16, rope_i_f16, rope_thd_f16, half)
 
@@ -535,5 +615,6 @@ ARGMIN(fast_argmin_bf16, bfloat, HUGE_VALBF)
 ARGMAX(fast_argmax_bf16, bfloat, -HUGE_VALBF)
 SOFTMAX(softmax_bf16, bfloat)
 RMSNORM(rmsnorm_bf16, bfloat)
+LAYERNORM(layernorm_bf16, bfloat)
 ROPE(rope_bf16, rope_i_bf16, rope_thd_bf16, bfloat)
 #endif
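A note on the threadgroup memory sizing above: rmsnorm reduces a single running sum per thread and reserves `(width * 4).max(16)` bytes, while layernorm keeps two partial reductions per thread (the sum and the sum of squares, laid out as `[sums | squared sums]` in `shared_memory`), hence the doubled `(width * 8).max(32)` bytes in `call_layer_norm`. One threadgroup is dispatched per output row (`thread_group_count.width = out_length`). A sequential CPU mirror of the two-accumulator tree reduction, in Rust (a hypothetical `block_reduce` helper for illustration; it assumes `block_dim` is a power of two, which the host guarantees via `next_power_of_two()`):

/// Sequential mirror of the kernel's threadgroup reduction. `shared` holds
/// `block_dim` partial sums followed by `block_dim` partial sums of squares:
/// 2 * block_dim floats, i.e. the width * 8 bytes reserved on the Metal side.
fn block_reduce(row: &[f32], block_dim: usize) -> (f32, f32) {
    let mut shared = vec![0f32; 2 * block_dim];
    // Each "thread" tid strides over the row with step block_dim.
    for tid in 0..block_dim {
        let (mut tmp1, mut tmp2) = (0f32, 0f32);
        let mut idx = tid;
        while idx < row.len() {
            tmp1 += row[idx];
            tmp2 += row[idx] * row[idx];
            idx += block_dim;
        }
        shared[tid] = tmp1;
        shared[block_dim + tid] = tmp2;
    }
    // Tree reduction halving the active range each step, as in the kernel's
    // `for (uint s = block_dim / 2; s > 0; s >>= 1)` loop.
    let mut s = block_dim / 2;
    while s > 0 {
        for tid in 0..s {
            shared[tid] += shared[tid + s];
            shared[block_dim + tid] += shared[block_dim + tid + s];
        }
        s /= 2;
    }
    // Total sum and total sum of squares; the kernel then derives
    // mean = shared[0] / n and var = shared[block_dim] / n - mean * mean.
    (shared[0], shared[block_dim])
}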