summaryrefslogtreecommitdiff
path: root/candle-kernels
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2024-04-05 08:32:58 +0200
committerGitHub <noreply@github.com>2024-04-05 08:32:58 +0200
commit2ac302a5d170953a1d2fe850645563fc55d1567f (patch)
treeb35b32efe5c8eac25a9b5681fb0778ef84e57d0e /candle-kernels
parentace282e5c2ef24ca2fb90683babb852936d4df17 (diff)
downloadcandle-2ac302a5d170953a1d2fe850645563fc55d1567f.tar.gz
candle-2ac302a5d170953a1d2fe850645563fc55d1567f.tar.bz2
candle-2ac302a5d170953a1d2fe850645563fc55d1567f.zip
Add the rope THD kernel. (#2014)
* Add the rope THD kernel. * Cuda kernel for rope-thd. * Add the metal kernels. * Add a dedicated test.
Diffstat (limited to 'candle-kernels')
-rw-r--r--candle-kernels/src/reduce.cu48
1 files changed, 43 insertions, 5 deletions
diff --git a/candle-kernels/src/reduce.cu b/candle-kernels/src/reduce.cu
index 2af81c42..4dbd8dcc 100644
--- a/candle-kernels/src/reduce.cu
+++ b/candle-kernels/src/reduce.cu
@@ -180,6 +180,33 @@ __device__ void rope(const T * src, const T * cos, const T * sin, T * dst, const
}
template <typename T>
+__device__ void rope_thd(
+ const T * src,
+ const T * cos,
+ const T * sin,
+ T * dst,
+ const uint32_t b,
+ const uint32_t t,
+ const uint32_t h,
+ const uint32_t d
+) {
+ const int idx = blockIdx.x * blockDim.x + threadIdx.x;
+ if (2 * idx >= b * t * h * d) return;
+
+ uint32_t i_bth = idx / (d / 2);
+ uint32_t i_d = idx - (d / 2) * i_bth;
+ uint32_t i_t = (i_bth / h) % t;
+ uint32_t i1 = i_bth * d + i_d;
+ uint32_t i2 = i1 + d / 2;
+ uint32_t i_cs = i_t * (d / 2) + i_d;
+ T c = cos[i_cs];
+ T s = sin[i_cs];
+
+ dst[i1] = src[i1] * c - src[i2] * s;
+ dst[i2] = src[i1] * s + src[i2] * c;
+}
+
+template <typename T>
__device__ void
fast_max(const size_t src_numel, const size_t el_to_sum_per_block,
const size_t num_dims, const size_t *info, const T *src, T *dst) {
@@ -434,7 +461,7 @@ fast_argmax(const size_t src_numel, const size_t el_to_sum_per_block,
rmsnorm<TYPENAME>(src, dst, alpha, n_cols, eps); \
} \
-#define ROPE_OP(TYPENAME, FN_NAME, FN_NAME_I) \
+#define ROPE_OP(TYPENAME, FN_NAME, FN_NAME_I, FN_NAME_THD) \
extern "C" __global__ void FN_NAME_I( \
const TYPENAME *src, \
const TYPENAME *cos, \
@@ -454,11 +481,22 @@ fast_argmax(const size_t src_numel, const size_t el_to_sum_per_block,
const uint32_t d) { \
rope<TYPENAME>(src, cos, sin, dst, bh, td, d); \
} \
+ extern "C" __global__ void FN_NAME_THD( \
+ const TYPENAME *src, \
+ const TYPENAME *cos, \
+ const TYPENAME *sin, \
+ TYPENAME *dst, \
+ const uint32_t b, \
+ const uint32_t t, \
+ const uint32_t h, \
+ const uint32_t d) { \
+ rope_thd<TYPENAME>(src, cos, sin, dst, b, t, h, d); \
+ } \
#if __CUDA_ARCH__ >= 800
SOFTMAX_OP(__nv_bfloat16, float, softmax_bf16)
RMSNORM_OP(__nv_bfloat16, rmsnorm_bf16)
-ROPE_OP(__nv_bfloat16, rope_bf16, rope_i_bf16)
+ROPE_OP(__nv_bfloat16, rope_bf16, rope_i_bf16, rope_thd_bf16)
SUM_OP(__nv_bfloat16, sum_bf16)
FAST_OP(__nv_bfloat16, fast_min_bf16, fast_max_bf16, fast_argmin_bf16, fast_argmax_bf16, fast_sum_bf16)
#endif
@@ -466,7 +504,7 @@ FAST_OP(__nv_bfloat16, fast_min_bf16, fast_max_bf16, fast_argmin_bf16, fast_argm
#if __CUDA_ARCH__ >= 530
SOFTMAX_OP(__half, float, softmax_f16)
RMSNORM_OP(__half, rmsnorm_f16)
-ROPE_OP(__half, rope_f16, rope_i_f16)
+ROPE_OP(__half, rope_f16, rope_i_f16, rope_thd_f16)
SUM_OP(__half, sum_f16)
FAST_OP(__half, fast_min_f16, fast_max_f16, fast_argmin_f16, fast_argmax_f16, fast_sum_f16)
#endif
@@ -478,8 +516,8 @@ SOFTMAX_OP(float, float, softmax_f32)
SOFTMAX_OP(double, double, softmax_f64)
RMSNORM_OP(float, rmsnorm_f32)
RMSNORM_OP(double, rmsnorm_f64)
-ROPE_OP(float, rope_f32, rope_i_f32)
-ROPE_OP(double, rope_f64, rope_i_f64)
+ROPE_OP(float, rope_f32, rope_i_f32, rope_thd_f32)
+ROPE_OP(double, rope_f64, rope_i_f64, rope_thd_f64)
FAST_OP(float, fast_min_f32, fast_max_f32, fast_argmin_f32, fast_argmax_f32, fast_sum_f32)
FAST_OP(double, fast_min_f64, fast_max_f64, fast_argmin_f64, fast_argmax_f64, fast_sum_f64)