diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-04-05 08:32:58 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-04-05 08:32:58 +0200 |
commit | 2ac302a5d170953a1d2fe850645563fc55d1567f (patch) | |
tree | b35b32efe5c8eac25a9b5681fb0778ef84e57d0e /candle-kernels | |
parent | ace282e5c2ef24ca2fb90683babb852936d4df17 (diff) | |
download | candle-2ac302a5d170953a1d2fe850645563fc55d1567f.tar.gz candle-2ac302a5d170953a1d2fe850645563fc55d1567f.tar.bz2 candle-2ac302a5d170953a1d2fe850645563fc55d1567f.zip |
Add the rope THD kernel. (#2014)
* Add the rope THD kernel.
* Cuda kernel for rope-thd.
* Add the metal kernels.
* Add a dedicated test.

Diffstat (limited to 'candle-kernels'):

| Mode | File | Lines changed |
|---|---|---|
| -rw-r--r-- | candle-kernels/src/reduce.cu | 48 |

1 file changed, 43 insertions, 5 deletions

```diff
diff --git a/candle-kernels/src/reduce.cu b/candle-kernels/src/reduce.cu
index 2af81c42..4dbd8dcc 100644
--- a/candle-kernels/src/reduce.cu
+++ b/candle-kernels/src/reduce.cu
@@ -180,6 +180,33 @@ __device__ void rope(const T * src, const T * cos, const T * sin, T * dst, const
 }
 
 template <typename T>
+__device__ void rope_thd(
+    const T * src,
+    const T * cos,
+    const T * sin,
+    T * dst,
+    const uint32_t b,
+    const uint32_t t,
+    const uint32_t h,
+    const uint32_t d
+) {
+    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
+    if (2 * idx >= b * t * h * d) return;
+
+    uint32_t i_bth = idx / (d / 2);
+    uint32_t i_d = idx - (d / 2) * i_bth;
+    uint32_t i_t = (i_bth / h) % t;
+    uint32_t i1 = i_bth * d + i_d;
+    uint32_t i2 = i1 + d / 2;
+    uint32_t i_cs = i_t * (d / 2) + i_d;
+    T c = cos[i_cs];
+    T s = sin[i_cs];
+
+    dst[i1] = src[i1] * c - src[i2] * s;
+    dst[i2] = src[i1] * s + src[i2] * c;
+}
+
+template <typename T>
 __device__ void fast_max(const size_t src_numel, const size_t el_to_sum_per_block,
                          const size_t num_dims, const size_t *info,
                          const T *src, T *dst) {
@@ -434,7 +461,7 @@ fast_argmax(const size_t src_numel, const size_t el_to_sum_per_block,
     rmsnorm<TYPENAME>(src, dst, alpha, n_cols, eps); \
   } \
 
-#define ROPE_OP(TYPENAME, FN_NAME, FN_NAME_I) \
+#define ROPE_OP(TYPENAME, FN_NAME, FN_NAME_I, FN_NAME_THD) \
   extern "C" __global__ void FN_NAME_I( \
       const TYPENAME *src, \
       const TYPENAME *cos, \
@@ -454,11 +481,22 @@ fast_argmax(const size_t src_numel, const size_t el_to_sum_per_block,
       const uint32_t d) { \
     rope<TYPENAME>(src, cos, sin, dst, bh, td, d); \
   } \
+  extern "C" __global__ void FN_NAME_THD( \
+      const TYPENAME *src, \
+      const TYPENAME *cos, \
+      const TYPENAME *sin, \
+      TYPENAME *dst, \
+      const uint32_t b, \
+      const uint32_t t, \
+      const uint32_t h, \
+      const uint32_t d) { \
+    rope_thd<TYPENAME>(src, cos, sin, dst, b, t, h, d); \
+  } \
 
 #if __CUDA_ARCH__ >= 800
 SOFTMAX_OP(__nv_bfloat16, float, softmax_bf16)
 RMSNORM_OP(__nv_bfloat16, rmsnorm_bf16)
-ROPE_OP(__nv_bfloat16, rope_bf16, rope_i_bf16)
+ROPE_OP(__nv_bfloat16, rope_bf16, rope_i_bf16, rope_thd_bf16)
 SUM_OP(__nv_bfloat16, sum_bf16)
 FAST_OP(__nv_bfloat16, fast_min_bf16, fast_max_bf16, fast_argmin_bf16, fast_argmax_bf16, fast_sum_bf16)
 #endif
@@ -466,7 +504,7 @@ FAST_OP(__nv_bfloat16, fast_min_bf16, fast_max_bf16, fast_argmin_bf16, fast_argm
 #if __CUDA_ARCH__ >= 530
 SOFTMAX_OP(__half, float, softmax_f16)
 RMSNORM_OP(__half, rmsnorm_f16)
-ROPE_OP(__half, rope_f16, rope_i_f16)
+ROPE_OP(__half, rope_f16, rope_i_f16, rope_thd_f16)
 SUM_OP(__half, sum_f16)
 FAST_OP(__half, fast_min_f16, fast_max_f16, fast_argmin_f16, fast_argmax_f16, fast_sum_f16)
 #endif
@@ -478,8 +516,8 @@ SOFTMAX_OP(float, float, softmax_f32)
 SOFTMAX_OP(double, double, softmax_f64)
 RMSNORM_OP(float, rmsnorm_f32)
 RMSNORM_OP(double, rmsnorm_f64)
-ROPE_OP(float, rope_f32, rope_i_f32)
-ROPE_OP(double, rope_f64, rope_i_f64)
+ROPE_OP(float, rope_f32, rope_i_f32, rope_thd_f32)
+ROPE_OP(double, rope_f64, rope_i_f64, rope_thd_f64)
 
 FAST_OP(float, fast_min_f32, fast_max_f32, fast_argmin_f32, fast_argmax_f32, fast_sum_f32)
 FAST_OP(double, fast_min_f64, fast_max_f64, fast_argmin_f64, fast_argmax_f64, fast_sum_f64)
```