author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-05-24 15:58:01 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-05-24 15:58:01 +0200 |
commit | 1df2bddccfbb4ab511a8cc3a87476d1fa72416bc (patch) | |
tree | 3633bc51e3bac3d542d9dfe06d509db20f5374e9 /candle-kernels | |
parent | 6f0b807ffd553fed27325a2a118b0e30bb6d9cbd (diff) | |
Add the layernorm specialized op. (#2212)
* Add the layernorm cuda kernels.
* Dedicated layer norm op.
* Add the slower variant.
* Plug the cuda implementation (a host-side launch sketch follows this list).
* Add the metal variant.
* Add a dedicated test.
* Bugfix.
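For readers who want to exercise the new kernels outside of candle, a minimal host-side harness is sketched below. This is not candle's own launch code: the launch geometry (one block per row, a 1024-thread block so that all 32 per-warp partials in shared memory get written) and the stand-alone build setup are assumptions; only the `layernorm_f32` entry-point signature comes from the diff further down, and the file must be compiled with nvcc and linked against the object code built from `candle-kernels/src/reduce.cu`.

```cuda
// Hypothetical harness for the layernorm_f32 kernel exported by reduce.cu.
// Launch geometry is an assumption, not candle's actual Rust-side wrapper.
#include <cstdio>
#include <vector>
#include <cuda_runtime.h>

// Declaration matching the LAYERNORM_OP(float, layernorm_f32) instantiation.
extern "C" __global__ void layernorm_f32(const float *src, float *dst,
                                         const float *alpha, const float *beta,
                                         const int n_cols, const float eps);

int main() {
  const int n_rows = 4, n_cols = 256;
  std::vector<float> h_src(n_rows * n_cols), h_dst(n_rows * n_cols);
  std::vector<float> h_alpha(n_cols, 1.0f), h_beta(n_cols, 0.0f);
  for (size_t i = 0; i < h_src.size(); ++i) h_src[i] = 0.01f * (float)i;

  float *d_src, *d_dst, *d_alpha, *d_beta;
  cudaMalloc((void**)&d_src, h_src.size() * sizeof(float));
  cudaMalloc((void**)&d_dst, h_dst.size() * sizeof(float));
  cudaMalloc((void**)&d_alpha, n_cols * sizeof(float));
  cudaMalloc((void**)&d_beta, n_cols * sizeof(float));
  cudaMemcpy(d_src, h_src.data(), h_src.size() * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(d_alpha, h_alpha.data(), n_cols * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(d_beta, h_beta.data(), n_cols * sizeof(float), cudaMemcpyHostToDevice);

  // Assumed geometry: one block per row (blockDim.y == 1, so row == blockIdx.x),
  // and 1024 threads per block (32 full warps) so every s_sum slot used by the
  // kernel's block-level reduction is populated. Threads beyond n_cols simply
  // contribute zero to the partial sums.
  dim3 grid(n_rows, 1, 1);
  dim3 block(1024, 1, 1);
  layernorm_f32<<<grid, block>>>(d_src, d_dst, d_alpha, d_beta, n_cols, 1e-5f);

  cudaMemcpy(h_dst.data(), d_dst, h_dst.size() * sizeof(float), cudaMemcpyDeviceToHost);
  printf("dst[0..3] = %f %f %f %f\n", h_dst[0], h_dst[1], h_dst[2], h_dst[3]);

  cudaFree(d_src); cudaFree(d_dst); cudaFree(d_alpha); cudaFree(d_beta);
  return 0;
}
```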
Diffstat (limited to 'candle-kernels')
-rw-r--r-- | candle-kernels/src/reduce.cu | 84 |
1 file changed, 84 insertions, 0 deletions
diff --git a/candle-kernels/src/reduce.cu b/candle-kernels/src/reduce.cu
index 4dbd8dcc..aaac24a1 100644
--- a/candle-kernels/src/reduce.cu
+++ b/candle-kernels/src/reduce.cu
@@ -50,6 +50,15 @@ fast_sum(const size_t src_numel, const size_t el_to_sum_per_block,
   dst[dst_id] = shr[0];
 }
 
+static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
+#pragma unroll
+  for (int mask = 16; mask > 0; mask >>= 1) {
+    a.x += __shfl_xor_sync(0xffffffff, a.x, mask, 32);
+    a.y += __shfl_xor_sync(0xffffffff, a.y, mask, 32);
+  }
+  return a;
+}
+
 static __device__ __forceinline__ float warp_reduce_sum(float x) {
 #pragma unroll
   for (int mask = 16; mask > 0; mask >>= 1) {
@@ -58,6 +67,70 @@ static __device__ __forceinline__ float warp_reduce_sum(float x) {
   return x;
 }
 
+// LayerNorm implementation adapted from ggml, accumulation is made using f32.
+// https://github.com/ggerganov/llama.cpp/blob/d59bd97065cd7ded6c4ecab54b1d5e0b1b11e318/ggml-cuda.cu#L477
+template <typename T>
+__device__ void layernorm(const T * x, T * dst, const T * alpha, const T * beta, const int ncols, const float eps) {
+  const int row = blockIdx.x*blockDim.y + threadIdx.y;
+  const int tid = threadIdx.x;
+  const int block_size = blockDim.x;
+
+  float2 mean_var = make_float2(0.f, 0.f);
+
+  for (int col = tid; col < ncols; col += block_size) {
+    const float xi = x[row*ncols + col];
+    mean_var.x += xi;
+    mean_var.y += xi * xi;
+  }
+
+  // sum up partial sums
+  mean_var = warp_reduce_sum(mean_var);
+  if (block_size > WARP_SIZE) {
+    __shared__ float2 s_sum[32];
+    int warp_id = threadIdx.x / WARP_SIZE;
+    int lane_id = threadIdx.x % WARP_SIZE;
+    if (lane_id == 0) {
+      s_sum[warp_id] = mean_var;
+    }
+    __syncthreads();
+    mean_var = s_sum[lane_id];
+    mean_var = warp_reduce_sum(mean_var);
+  }
+
+  const float mean = mean_var.x / ncols;
+  const float var = mean_var.y / ncols - mean * mean;
+  const float inv_std = rsqrtf(var + eps);
+
+  if (alpha == nullptr && beta == nullptr) {
+    for (int col = tid; col < ncols; col += block_size) {
+      float lhs = (static_cast<float>(x[row*ncols + col]) - mean) * inv_std;
+      dst[row*ncols + col] = static_cast<T>(lhs);
+    }
+  }
+  else if (alpha == nullptr && beta != nullptr) {
+    for (int col = tid; col < ncols; col += block_size) {
+      float b = static_cast<float>(beta[col]);
+      float lhs = (static_cast<float>(x[row*ncols + col]) - mean) * inv_std;
+      dst[row*ncols + col] = static_cast<T>(lhs + b);
+    }
+  }
+  else if (alpha != nullptr && beta == nullptr) {
+    for (int col = tid; col < ncols; col += block_size) {
+      float a = static_cast<float>(alpha[col]);
+      float lhs = (static_cast<float>(x[row*ncols + col]) - mean) * inv_std;
+      dst[row*ncols + col] = static_cast<T>(lhs * a);
+    }
+  }
+  else {
+    for (int col = tid; col < ncols; col += block_size) {
+      float a = static_cast<float>(alpha[col]);
+      float b = static_cast<float>(beta[col]);
+      float lhs = (static_cast<float>(x[row*ncols + col]) - mean) * inv_std;
+      dst[row*ncols + col] = static_cast<T>(lhs * a + b);
+    }
+  }
+}
+
 // RmsNorm implementation adapted from ggml, accumulation is made using f32.
 // https://github.com/ggerganov/llama.cpp/blob/d59bd97065cd7ded6c4ecab54b1d5e0b1b11e318/ggml-cuda.cu#L523
 template <typename T>
@@ -461,6 +534,13 @@ fast_argmax(const size_t src_numel, const size_t el_to_sum_per_block,
     rmsnorm<TYPENAME>(src, dst, alpha, n_cols, eps); \
   } \
 
+#define LAYERNORM_OP(TYPENAME, FN_NAME) \
+  extern "C" __global__ void FN_NAME( \
+      const TYPENAME *src, TYPENAME *dst, const TYPENAME *alpha, \
+      const TYPENAME *beta, const int n_cols, const float eps) { \
+    layernorm<TYPENAME>(src, dst, alpha, beta, n_cols, eps); \
+  } \
+
 #define ROPE_OP(TYPENAME, FN_NAME, FN_NAME_I, FN_NAME_THD) \
   extern "C" __global__ void FN_NAME_I( \
       const TYPENAME *src, \
@@ -496,6 +576,7 @@ fast_argmax(const size_t src_numel, const size_t el_to_sum_per_block,
 #if __CUDA_ARCH__ >= 800
 SOFTMAX_OP(__nv_bfloat16, float, softmax_bf16)
 RMSNORM_OP(__nv_bfloat16, rmsnorm_bf16)
+LAYERNORM_OP(__nv_bfloat16, layernorm_bf16)
 ROPE_OP(__nv_bfloat16, rope_bf16, rope_i_bf16, rope_thd_bf16)
 SUM_OP(__nv_bfloat16, sum_bf16)
 FAST_OP(__nv_bfloat16, fast_min_bf16, fast_max_bf16, fast_argmin_bf16, fast_argmax_bf16, fast_sum_bf16)
@@ -504,6 +585,7 @@ FAST_OP(__nv_bfloat16, fast_min_bf16, fast_max_bf16, fast_argmin_bf16, fast_argm
 #if __CUDA_ARCH__ >= 530
 SOFTMAX_OP(__half, float, softmax_f16)
 RMSNORM_OP(__half, rmsnorm_f16)
+LAYERNORM_OP(__half, layernorm_f16)
 ROPE_OP(__half, rope_f16, rope_i_f16, rope_thd_f16)
 SUM_OP(__half, sum_f16)
 FAST_OP(__half, fast_min_f16, fast_max_f16, fast_argmin_f16, fast_argmax_f16, fast_sum_f16)
@@ -516,6 +598,8 @@ SOFTMAX_OP(float, float, softmax_f32)
 SOFTMAX_OP(double, double, softmax_f64)
 RMSNORM_OP(float, rmsnorm_f32)
 RMSNORM_OP(double, rmsnorm_f64)
+LAYERNORM_OP(float, layernorm_f32)
+LAYERNORM_OP(double, layernorm_f64)
 ROPE_OP(float, rope_f32, rope_i_f32, rope_thd_f32)
 ROPE_OP(double, rope_f64, rope_i_f64, rope_thd_f64)
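Written out, each row of `src` is normalized as the hunks above implement it: all accumulation is done in f32, the variance is computed as the mean of squares minus the squared mean, and `alpha`/`beta` act as 1/0 when the corresponding pointer is null.

$$
\mu = \frac{1}{n}\sum_{i=1}^{n} x_i,\qquad
\sigma^2 = \frac{1}{n}\sum_{i=1}^{n} x_i^2 - \mu^2,\qquad
y_i = \alpha_i\,\frac{x_i - \mu}{\sqrt{\sigma^2 + \varepsilon}} + \beta_i .
$$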