diff options
Diffstat (limited to 'candle-metal-kernels/src/reduce.metal')
-rw-r--r-- | candle-metal-kernels/src/reduce.metal | 81 |
1 files changed, 81 insertions, 0 deletions
diff --git a/candle-metal-kernels/src/reduce.metal b/candle-metal-kernels/src/reduce.metal index 14bfb297..e009ca1d 100644 --- a/candle-metal-kernels/src/reduce.metal +++ b/candle-metal-kernels/src/reduce.metal @@ -353,6 +353,65 @@ METAL_FUNC void rmsnorm( } } +template<typename T> +METAL_FUNC void layernorm( + constant size_t & src_numel, + constant size_t & el_to_sum_per_block, + device const T * src, + device T * dst, + device const T * alpha, + device const T * beta, + constant float & eps, + uint id, + uint tid, + uint dst_id, + uint block_dim, + threadgroup float * shared_memory +) { + size_t start_idx = dst_id * el_to_sum_per_block; + size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel); + size_t idx = start_idx + tid; + + float tmp1 = 0; + float tmp2 = 0; + while (idx < stop_idx) { + tmp1 += float(src[idx]); + tmp2 += float(src[idx]) * float(src[idx]); + idx += block_dim; + } + shared_memory[tid] = tmp1; + shared_memory[tid + block_dim] = tmp2; + + threadgroup_barrier(mem_flags::mem_threadgroup); + + for (uint s = block_dim / 2; s > 0; s >>= 1) { + if (tid < s) { + shared_memory[tid] = shared_memory[tid] + shared_memory[tid + s]; + shared_memory[block_dim + tid] = shared_memory[block_dim + tid] + shared_memory[block_dim + tid + s]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + /* wait for shared_memory[0] to be filled */ + threadgroup_barrier(mem_flags::mem_threadgroup); + + float mean = shared_memory[0] / float(el_to_sum_per_block); + float var = shared_memory[block_dim] / float(el_to_sum_per_block) - mean * mean; + float inv_norm = 1.0f / sqrt(var + eps); + idx = start_idx + tid; + while (idx < stop_idx) { + float val = (float(src[idx]) - mean) * inv_norm; + if (alpha != nullptr) { + val *= float(alpha[idx - start_idx]); + } + if (beta != nullptr) { + val += float(beta[idx - start_idx]); + } + dst[idx] = T(val); + idx += block_dim; + } +} + #define RMSNORM(NAME, T) \ kernel void NAME( \ constant size_t &src_numel, \ @@ -371,6 +430,25 @@ kernel void NAME( \ rmsnorm<T>(src_numel, el_to_sum_per_block, src, dst, alpha, eps, id, tid, dst_id, block_dim, shared_memory); \ } \ +#define LAYERNORM(NAME, T) \ +kernel void NAME( \ + constant size_t &src_numel, \ + constant size_t &el_to_sum_per_block, \ + device const T *src, \ + device T *dst, \ + device const T *alpha, \ + device const T *beta, \ + constant float &eps, \ + uint id [[ thread_position_in_grid ]], \ + uint tid [[ thread_index_in_threadgroup ]], \ + uint dst_id [[ threadgroup_position_in_grid ]], \ + uint block_dim [[ threads_per_threadgroup ]] \ +) { \ + threadgroup float shared_memory[THREADGROUP_SIZE]; \ + shared_memory[tid] = 0; \ + layernorm<T>(src_numel, el_to_sum_per_block, src, dst, alpha, beta, eps, id, tid, dst_id, block_dim, shared_memory); \ +} \ + template<typename T> METAL_FUNC void ropei( constant size_t &bh, @@ -511,6 +589,8 @@ SOFTMAX(softmax_f32, float) SOFTMAX(softmax_f16, half) RMSNORM(rmsnorm_f32, float) RMSNORM(rmsnorm_f16, half) +LAYERNORM(layernorm_f32, float) +LAYERNORM(layernorm_f16, half) ROPE(rope_f32, rope_i_f32, rope_thd_f32, float) ROPE(rope_f16, rope_i_f16, rope_thd_f16, half) @@ -535,5 +615,6 @@ ARGMIN(fast_argmin_bf16, bfloat, HUGE_VALBF) ARGMAX(fast_argmax_bf16, bfloat, -HUGE_VALBF) SOFTMAX(softmax_bf16, bfloat) RMSNORM(rmsnorm_bf16, bfloat) +LAYERNORM(layernorm_bf16, bfloat) ROPE(rope_bf16, rope_i_bf16, rope_thd_bf16, bfloat) #endif |