author     Laurent Mazare <laurent.mazare@gmail.com>  2024-05-24 15:58:01 +0200
committer  GitHub <noreply@github.com>  2024-05-24 15:58:01 +0200
commit     1df2bddccfbb4ab511a8cc3a87476d1fa72416bc (patch)
tree       3633bc51e3bac3d542d9dfe06d509db20f5374e9 /candle-kernels
parent     6f0b807ffd553fed27325a2a118b0e30bb6d9cbd (diff)
Add the layernorm specialized op. (#2212)
* Add the layernorm cuda kernels.
* Dedicated layer norm op.
* Add the slower variant.
* Plug the cuda implementation.
* Add the metal variant.
* Add a dedicated test.
* Bugfix.
Diffstat (limited to 'candle-kernels')
-rw-r--r--  candle-kernels/src/reduce.cu | 84
1 file changed, 84 insertions(+), 0 deletions(-)
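For reference, the op added here normalizes each row of the input to zero mean and unit variance, then applies an optional per-column scale (alpha) and shift (beta). A minimal single-threaded C++ sketch of the same math, useful as a mental model for the CUDA kernel in the diff below (the function name is illustrative, not part of this commit):

#include <math.h>

// Layer norm over one row of `ncols` values:
// y = (x - mean) / sqrt(var + eps) * alpha + beta, with var = E[x^2] - E[x]^2.
void layernorm_ref(const float *x, float *dst, const float *alpha,
                   const float *beta, int ncols, float eps) {
  float sum = 0.f, sum_sq = 0.f;
  for (int c = 0; c < ncols; ++c) {
    sum += x[c];
    sum_sq += x[c] * x[c];
  }
  const float mean = sum / ncols;
  const float var = sum_sq / ncols - mean * mean; // single-pass variance
  const float inv_std = 1.f / sqrtf(var + eps);
  for (int c = 0; c < ncols; ++c) {
    float y = (x[c] - mean) * inv_std;
    if (alpha) y *= alpha[c];
    if (beta) y += beta[c];
    dst[c] = y;
  }
}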
diff --git a/candle-kernels/src/reduce.cu b/candle-kernels/src/reduce.cu
index 4dbd8dcc..aaac24a1 100644
--- a/candle-kernels/src/reduce.cu
+++ b/candle-kernels/src/reduce.cu
@@ -50,6 +50,15 @@ fast_sum(const size_t src_numel, const size_t el_to_sum_per_block,
dst[dst_id] = shr[0];
}
+static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
+#pragma unroll
+ for (int mask = 16; mask > 0; mask >>= 1) {
+ a.x += __shfl_xor_sync(0xffffffff, a.x, mask, 32);
+ a.y += __shfl_xor_sync(0xffffffff, a.y, mask, 32);
+ }
+ return a;
+}
+
static __device__ __forceinline__ float warp_reduce_sum(float x) {
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
@@ -58,6 +67,70 @@ static __device__ __forceinline__ float warp_reduce_sum(float x) {
return x;
}
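The new float2 overload reduces the running sum and the running sum of squares in a single butterfly, mirroring the existing float version: at each step, __shfl_xor_sync exchanges values between lanes whose IDs differ in exactly one bit, so after the masks 16, 8, 4, 2, 1 every lane holds the full 32-lane total. A minimal standalone check (the kernel name is illustrative; assumes a single full warp):

// Each lane contributes (lane, lane^2); after the butterfly every lane holds
// the warp-wide totals: sum of 0..31 = 496, sum of squares = 10416.
__global__ void warp_reduce_demo(float2 *out) {
  const int lane = threadIdx.x; // launched with exactly 32 threads
  float2 v = make_float2((float)lane, (float)(lane * lane));
#pragma unroll
  for (int mask = 16; mask > 0; mask >>= 1) {
    v.x += __shfl_xor_sync(0xffffffff, v.x, mask, 32);
    v.y += __shfl_xor_sync(0xffffffff, v.y, mask, 32);
  }
  if (lane == 0) *out = v; // out->x == 496.f, out->y == 10416.f
}
// Launch: warp_reduce_demo<<<1, 32>>>(d_out);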
+// LayerNorm implementation adapted from ggml, accumulation is made using f32.
+// https://github.com/ggerganov/llama.cpp/blob/d59bd97065cd7ded6c4ecab54b1d5e0b1b11e318/ggml-cuda.cu#L477
+template <typename T>
+__device__ void layernorm(const T * x, T * dst, const T * alpha, const T * beta, const int ncols, const float eps) {
+ const int row = blockIdx.x*blockDim.y + threadIdx.y;
+ const int tid = threadIdx.x;
+ const int block_size = blockDim.x;
+
+ float2 mean_var = make_float2(0.f, 0.f);
+
+ for (int col = tid; col < ncols; col += block_size) {
+ const float xi = x[row*ncols + col];
+ mean_var.x += xi;
+ mean_var.y += xi * xi;
+ }
+
+ // sum up partial sums
+ mean_var = warp_reduce_sum(mean_var);
+ if (block_size > WARP_SIZE) {
+ __shared__ float2 s_sum[32];
+ int warp_id = threadIdx.x / WARP_SIZE;
+ int lane_id = threadIdx.x % WARP_SIZE;
+ if (lane_id == 0) {
+ s_sum[warp_id] = mean_var;
+ }
+ __syncthreads();
+ mean_var = s_sum[lane_id];
+ mean_var = warp_reduce_sum(mean_var);
+ }
+
+ const float mean = mean_var.x / ncols;
+ const float var = mean_var.y / ncols - mean * mean;
+ const float inv_std = rsqrtf(var + eps);
+
+ if (alpha == nullptr && beta == nullptr) {
+ for (int col = tid; col < ncols; col += block_size) {
+ float lhs = (static_cast<float>(x[row*ncols + col]) - mean) * inv_std;
+ dst[row*ncols + col] = static_cast<T>(lhs);
+ }
+ }
+ else if (alpha == nullptr && beta != nullptr) {
+ for (int col = tid; col < ncols; col += block_size) {
+ float b = static_cast<float>(beta[col]);
+ float lhs = (static_cast<float>(x[row*ncols + col]) - mean) * inv_std;
+ dst[row*ncols + col] = static_cast<T>(lhs + b);
+ }
+ }
+ else if (alpha != nullptr && beta == nullptr) {
+ for (int col = tid; col < ncols; col += block_size) {
+ float a = static_cast<float>(alpha[col]);
+ float lhs = (static_cast<float>(x[row*ncols + col]) - mean) * inv_std;
+ dst[row*ncols + col] = static_cast<T>(lhs * a);
+ }
+ }
+ else {
+ for (int col = tid; col < ncols; col += block_size) {
+ float a = static_cast<float>(alpha[col]);
+ float b = static_cast<float>(beta[col]);
+ float lhs = (static_cast<float>(x[row*ncols + col]) - mean) * inv_std;
+ dst[row*ncols + col] = static_cast<T>(lhs * a + b);
+ }
+ }
+}
+
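The kernel accumulates per-thread partial sums, reduces within each warp, and, when the block is wider than one warp, funnels the per-warp results through shared memory for a second warp-level reduction; the variance then falls out of the single-pass identity var = E[x^2] - E[x]^2. Note that the inter-warp step reads s_sum[lane_id] for all 32 lanes, so the block must either be a single warp or contain a full 32 warps, otherwise some lanes would sum slots no warp wrote. A hedged host-side launch sketch (the helper name and heuristic are illustrative, not part of this diff; with blockDim.y == 1 the kernel's row index reduces to blockIdx.x, i.e. one block per row):

// Hypothetical launcher for the layernorm_f32 entry point instantiated below.
// One warp for narrow rows, a full 1024-thread block (32 warps) otherwise,
// so the shared-memory reduction only reads slots a real warp wrote.
void launch_layernorm_f32(const float *x, float *dst, const float *alpha,
                          const float *beta, int nrows, int ncols, float eps) {
  const int block_size = ncols < 1024 ? 32 : 1024;
  layernorm_f32<<<dim3(nrows), dim3(block_size, 1)>>>(x, dst, alpha, beta,
                                                      ncols, eps);
}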
// RmsNorm implementation adapted from ggml, accumulation is made using f32.
// https://github.com/ggerganov/llama.cpp/blob/d59bd97065cd7ded6c4ecab54b1d5e0b1b11e318/ggml-cuda.cu#L523
template <typename T>
@@ -461,6 +534,13 @@ fast_argmax(const size_t src_numel, const size_t el_to_sum_per_block,
rmsnorm<TYPENAME>(src, dst, alpha, n_cols, eps); \
} \
+#define LAYERNORM_OP(TYPENAME, FN_NAME) \
+ extern "C" __global__ void FN_NAME( \
+ const TYPENAME *src, TYPENAME *dst, const TYPENAME *alpha, \
+ const TYPENAME *beta, const int n_cols, const float eps) { \
+ layernorm<TYPENAME>(src, dst, alpha, beta, n_cols, eps); \
+ } \
+
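Each LAYERNORM_OP(TYPENAME, FN_NAME) instantiation produces an extern "C" entry point that the host side can look up by name in the compiled module. For example, the LAYERNORM_OP(float, layernorm_f32) line added further down expands to:

extern "C" __global__ void layernorm_f32(
    const float *src, float *dst, const float *alpha,
    const float *beta, const int n_cols, const float eps) {
  layernorm<float>(src, dst, alpha, beta, n_cols, eps);
}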
#define ROPE_OP(TYPENAME, FN_NAME, FN_NAME_I, FN_NAME_THD) \
extern "C" __global__ void FN_NAME_I( \
const TYPENAME *src, \
@@ -496,6 +576,7 @@ fast_argmax(const size_t src_numel, const size_t el_to_sum_per_block,
#if __CUDA_ARCH__ >= 800
SOFTMAX_OP(__nv_bfloat16, float, softmax_bf16)
RMSNORM_OP(__nv_bfloat16, rmsnorm_bf16)
+LAYERNORM_OP(__nv_bfloat16, layernorm_bf16)
ROPE_OP(__nv_bfloat16, rope_bf16, rope_i_bf16, rope_thd_bf16)
SUM_OP(__nv_bfloat16, sum_bf16)
FAST_OP(__nv_bfloat16, fast_min_bf16, fast_max_bf16, fast_argmin_bf16, fast_argmax_bf16, fast_sum_bf16)
@@ -504,6 +585,7 @@ FAST_OP(__nv_bfloat16, fast_min_bf16, fast_max_bf16, fast_argmin_bf16, fast_argm
#if __CUDA_ARCH__ >= 530
SOFTMAX_OP(__half, float, softmax_f16)
RMSNORM_OP(__half, rmsnorm_f16)
+LAYERNORM_OP(__half, layernorm_f16)
ROPE_OP(__half, rope_f16, rope_i_f16, rope_thd_f16)
SUM_OP(__half, sum_f16)
FAST_OP(__half, fast_min_f16, fast_max_f16, fast_argmin_f16, fast_argmax_f16, fast_sum_f16)
@@ -516,6 +598,8 @@ SOFTMAX_OP(float, float, softmax_f32)
SOFTMAX_OP(double, double, softmax_f64)
RMSNORM_OP(float, rmsnorm_f32)
RMSNORM_OP(double, rmsnorm_f64)
+LAYERNORM_OP(float, layernorm_f32)
+LAYERNORM_OP(double, layernorm_f64)
ROPE_OP(float, rope_f32, rope_i_f32, rope_thd_f32)
ROPE_OP(double, rope_f64, rope_i_f64, rope_thd_f64)
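As a quick end-to-end sanity check of the new f32 entry point, a sketch assuming reduce.cu is compiled into the same translation unit (error handling omitted; 32 threads per block keeps the kernel on its warp-only reduction path):

#include <cstdio>
#include <cuda_runtime.h>

int main() {
  const int nrows = 2, ncols = 256;
  float *x, *dst; // unified memory for brevity
  cudaMallocManaged(&x, nrows * ncols * sizeof(float));
  cudaMallocManaged(&dst, nrows * ncols * sizeof(float));
  for (int i = 0; i < nrows * ncols; ++i) x[i] = (float)(i % 7);
  // alpha == beta == nullptr exercises the plain normalization branch.
  layernorm_f32<<<dim3(nrows), dim3(32, 1)>>>(x, dst, nullptr, nullptr,
                                              ncols, 1e-5f);
  cudaDeviceSynchronize();
  printf("dst[0] = %f\n", dst[0]); // approximately (x[0] - row mean) / row std
  cudaFree(x);
  cudaFree(dst);
  return 0;
}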