path: root/candle-metal-kernels
author    Thomas Santerre <thomas@santerre.xyz>    2024-04-04 02:13:12 -0400
committer GitHub <noreply@github.com>              2024-04-04 08:13:12 +0200
commit    bd8db2a7712e14ea76a80475905db04bbf402aa6 (patch)
tree      9de047145a4f4e9a44fd13e6cf11d88a976d5178 /candle-metal-kernels
parent    318d143224805e490d396874b9e1aaf28991393c (diff)
refactor to reduce the amount of code wrapped in template syntax (#2002)
Diffstat (limited to 'candle-metal-kernels')
-rw-r--r--  candle-metal-kernels/src/reduce.metal | 629
1 file changed, 368 insertions, 261 deletions
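
The change applies one pattern throughout the file: the body of each macro-generated kernel moves into an ordinary templated METAL_FUNC helper, and the macro shrinks to a thin per-dtype entry point. Outside the macro, lines no longer need trailing backslashes and plain // comments become usable. A minimal sketch of the pattern, with hypothetical names not taken from this commit:

    // Before: the whole body lives in the macro, so every line needs a
    // trailing backslash and // comments cannot be used.
    #define SCALE_KERNEL(NAME, T)                      \
    kernel void NAME(                                  \
        device const T *src,                           \
        device T *dst,                                 \
        uint tid [[ thread_position_in_grid ]]         \
    ) {                                                \
        dst[tid] = src[tid] * T(2);                    \
    }                                                  \

    // After: the logic is regular templated code; the macro only
    // instantiates a named kernel entry point per dtype.
    template<typename T>
    METAL_FUNC void scale(device const T *src, device T *dst, uint tid) {
        dst[tid] = src[tid] * T(2);
    }

    #define SCALE_KERNEL(NAME, T)                      \
    kernel void NAME(                                  \
        device const T *src,                           \
        device T *dst,                                 \
        uint tid [[ thread_position_in_grid ]]         \
    ) {                                                \
        scale<T>(src, dst, tid);                       \
    }                                                  \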
diff --git a/candle-metal-kernels/src/reduce.metal b/candle-metal-kernels/src/reduce.metal
index be5a0921..d06efbf2 100644
--- a/candle-metal-kernels/src/reduce.metal
+++ b/candle-metal-kernels/src/reduce.metal
@@ -21,6 +21,59 @@ METAL_FUNC uint get_strided_index(
constant int THREADGROUP_SIZE = 2048;
+template<typename T>
+METAL_FUNC void argmin(
+ constant size_t &num_dims,
+ constant size_t *dims,
+ constant size_t *strides,
+ constant size_t &el_to_sum_per_block,
+ device const T *src,
+ device uint *dst,
+ uint id,
+ uint tid,
+ uint dst_id,
+ uint block_dim,
+ threadgroup T *shared_memory,
+ threadgroup uint *shared_indices
+) {
+ bool notset = true;
+    // Elements summed in this block range from dst_id * el_to_sum_per_block
+    // to (dst_id + 1) * el_to_sum_per_block.
+ size_t start_idx = dst_id * el_to_sum_per_block;
+ size_t stop_idx = start_idx + el_to_sum_per_block;
+ size_t idx = start_idx + tid;
+ while (idx < stop_idx) {
+        // TODO: Fast version for the contiguous case.
+ size_t strided_i = get_strided_index(idx, num_dims, dims, strides);
+ if (notset || src[strided_i] < shared_memory[tid]) {
+ shared_memory[tid] = src[strided_i];
+            // Assume that the reduction takes place over the last dimension, which is contiguous.
+ shared_indices[tid] = idx % dims[num_dims - 1];
+ notset = false;
+ }
+ idx += block_dim;
+ }
+
+ threadgroup_barrier(mem_flags::mem_none);
+    // reduction in shared memory
+ for (uint s = block_dim / 2; s > 0; s >>= 1) {
+ if (tid < s && shared_memory[tid + s] < shared_memory[tid]) {
+ shared_indices[tid] = shared_indices[tid + s];
+ shared_memory[tid] = shared_memory[tid + s];
+        }
+ threadgroup_barrier(mem_flags::mem_none);
+ }
+
+    if (tid == 0) {
+ dst[dst_id] = shared_indices[0];
+ }
+}
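
The helper keeps the classic two-phase shape: each thread first scans its strided slice of the block serially, then the loop tree-reduces the per-thread candidates, halving the active thread count each step, so after ceil(log2(block_dim)) iterations

\[
\texttt{dst[dst\_id]} = \Big(\operatorname*{arg\,min}_{\text{start\_idx} \le i < \text{stop\_idx}} \texttt{src}[\mathrm{strided}(i)]\Big) \bmod \texttt{dims[num\_dims - 1]}
\]

Note that both comparisons are strict, so on ties the surviving index is whichever candidate happens to win the tree, not necessarily the smallest one.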
#define ARGMIN(NAME, T, MAXVALUE) \
kernel void NAME( \
@@ -35,53 +88,71 @@ kernel void NAME( \
uint dst_id [[ threadgroup_position_in_grid ]], \
uint block_dim [[ threads_per_threadgroup ]] \
) { \
- \
- threadgroup T shared_memory[THREADGROUP_SIZE]; \
- threadgroup uint shared_indices[THREADGROUP_SIZE]; \
- \
- shared_memory[tid] = MAXVALUE; \
- shared_indices[tid] = 0xFFFFFFFF; \
- bool notset = true; \
- /* \
- // Elements summed in this block range from dst_id * el_to_sum_per_block \
- // to (dst_id + 1) * el_to_sum_per_block. \
- */ \
- size_t start_idx = dst_id * el_to_sum_per_block; \
- size_t stop_idx = start_idx + el_to_sum_per_block; \
- size_t idx = start_idx + tid; \
- while (idx < stop_idx) { \
- /* \
- // TODO: Fast version for the contiguous case. \
- */ \
- size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
- if (notset || src[strided_i] < shared_memory[tid]) { \
- shared_memory[tid] = src[strided_i]; \
- /* Assume that the reduction takes place over the last dimension which is contiguous. */ \
- shared_indices[tid] = idx % dims[num_dims - 1]; \
- notset = false; \
- } \
- idx += block_dim; \
- } \
- \
- threadgroup_barrier(mem_flags::mem_none); \
- \
- /* \
- // reduction in shared memory \
- */ \
- for (uint s = block_dim / 2; s > 0; s >>= 1) { \
- if (tid < s && shared_memory[tid + s] < shared_memory[tid]) { \
- shared_indices[tid] = shared_indices[tid + s]; \
- shared_memory[tid] = shared_memory[tid + s]; \
- } \
- threadgroup_barrier(mem_flags::mem_none); \
- } \
- \
- if (tid == 0){ \
- dst[dst_id] = shared_indices[0]; \
- } \
+ threadgroup T shared_memory[THREADGROUP_SIZE]; \
+ threadgroup uint shared_indices[THREADGROUP_SIZE]; \
+ shared_memory[tid] = MAXVALUE; \
+ shared_indices[tid] = 0xFFFFFFFF; \
+ argmin<T>(num_dims, dims, strides, el_to_sum_per_block, src, dst, id, tid, dst_id, block_dim, shared_memory, shared_indices); \
} \
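
Only setup now survives macro-side: the wrapper allocates the threadgroup buffers and seeds them with MAXVALUE and 0xFFFFFFFF (so lanes that never load an element lose every comparison) before delegating to argmin<T>. Per-dtype kernels then come from one-line instantiations further down the file, along these lines (name and sentinel illustrative, not part of this hunk):

    ARGMIN(fast_argmin_f32_strided, float, HUGE_VALF)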
+template<typename T>
+METAL_FUNC void argmax(
+ constant size_t & num_dims,
+ constant size_t * dims,
+ constant size_t * strides,
+ constant size_t & el_to_sum_per_block,
+ device const T * src,
+ device uint * dst,
+ uint id,
+ uint tid,
+ uint dst_id,
+ uint block_dim,
+ threadgroup T * shared_memory,
+ threadgroup uint * shared_indices
+ ) {
+    // Elements summed in this block range from dst_id * el_to_sum_per_block
+    // to (dst_id + 1) * el_to_sum_per_block.
+ size_t start_idx = dst_id * el_to_sum_per_block;
+ size_t stop_idx = start_idx + el_to_sum_per_block;
+ size_t idx = start_idx + tid;
+ bool notset = true;
+ while (idx < stop_idx) {
+        // TODO: Fast version for the contiguous case.
+ size_t strided_i = get_strided_index(idx, num_dims, dims, strides);
+ if (notset || shared_memory[tid] < src[strided_i]) {
+ shared_memory[tid] = src[strided_i];
+ shared_indices[tid] = idx % dims[num_dims - 1];
+ notset = false;
+ }
+ idx += block_dim;
+ }
+
+ threadgroup_barrier(mem_flags::mem_none);
+
+    // reduction in shared memory
+ for (uint s = block_dim / 2; s > 0; s >>= 1) {
+ if (tid < s && shared_memory[tid + s] > shared_memory[tid]) {
+ shared_indices[tid] = shared_indices[tid + s];
+ shared_memory[tid] = shared_memory[tid + s];
+ }
+ threadgroup_barrier(mem_flags::mem_none);
+ }
+
+    // Thread 0 writes the result of the reduction
+ if (tid == 0) {
+ dst[dst_id] = shared_indices[0];
+ }
+ }
+
#define ARGMAX(NAME, T, MINVALUE) \
kernel void NAME( \
constant size_t &num_dims, \
@@ -95,223 +166,279 @@ kernel void NAME( \
uint dst_id [[ threadgroup_position_in_grid ]], \
uint block_dim [[ threads_per_threadgroup ]] \
) { \
- \
threadgroup T shared_memory[THREADGROUP_SIZE]; \
threadgroup uint shared_indices[THREADGROUP_SIZE]; \
- \
shared_memory[tid] = MINVALUE; \
shared_indices[tid] = 0xFFFFFFFF; \
- /* \
- // Elements summed in this block range from dst_id * el_to_sum_per_block \
- // to (dst_id + 1) * el_to_sum_per_block. \
- */ \
- size_t start_idx = dst_id * el_to_sum_per_block; \
- size_t stop_idx = start_idx + el_to_sum_per_block; \
- size_t idx = start_idx + tid; \
- bool notset = true; \
- while (idx < stop_idx) { \
- /* \
- // TODO: Fast version for the contiguous case. \
- */ \
- size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
- if (notset || shared_memory[tid] < src[strided_i]) { \
- shared_memory[tid] = src[strided_i]; \
- shared_indices[tid] = idx % dims[num_dims - 1]; \
- notset = false; \
- } \
- idx += block_dim; \
- } \
- \
- threadgroup_barrier(mem_flags::mem_none); \
- \
- /* \
- // reduction in shared memory \
- */ \
- for (uint s = block_dim / 2; s > 0; s >>= 1) { \
- if (tid < s && shared_memory[tid + s] > shared_memory[tid]) { \
- shared_indices[tid] = shared_indices[tid + s]; \
- shared_memory[tid] = shared_memory[tid + s]; \
- } \
- threadgroup_barrier(mem_flags::mem_none); \
- } \
- \
- if (tid == 0){ \
- dst[dst_id] = shared_indices[0]; \
- } \
+ argmax<T>(num_dims, dims, strides, el_to_sum_per_block, src, dst, id, tid, dst_id, block_dim, shared_memory, shared_indices); \
} \
+template<typename T>
+METAL_FUNC void reduce(
+ constant size_t & num_dims,
+ constant size_t * dims,
+ constant size_t * strides,
+ constant size_t & el_to_sum_per_block,
+ device const T * src,
+ device T * dst,
+ uint id,
+ uint tid,
+ uint dst_id,
+ uint block_dim,
+ threadgroup T * shared_memory,
+ T (*fn)(T, T)
+) {
+    // Elements summed in this block range from dst_id * el_to_sum_per_block
+    // to (dst_id + 1) * el_to_sum_per_block.
+ size_t start_idx = dst_id * el_to_sum_per_block;
+ size_t stop_idx = start_idx + el_to_sum_per_block;
+ size_t idx = start_idx + tid;
+ while (idx < stop_idx) {
+        // TODO: Fast version for the contiguous case.
+ size_t strided_i = get_strided_index(idx, num_dims, dims, strides);
+ T x = shared_memory[tid];
+ T y = src[strided_i];
+ shared_memory[tid] = fn(x, y);
+ idx += block_dim;
+ }
+
+ threadgroup_barrier(mem_flags::mem_none);
+
+    // reduction in shared memory
+ for (uint s = block_dim / 2; s > 0; s >>= 1) {
+ if (tid < s) {
+ T x = shared_memory[tid];
+ T y = shared_memory[tid + s];
+ shared_memory[tid] = fn(x, y);
+ }
+ threadgroup_barrier(mem_flags::mem_none);
+ }
+
+ if (tid == 0) {
+ dst[dst_id] = shared_memory[0];
+ }
+}
+
#define REDUCE(FN, NAME, T, START) \
+METAL_FUNC T NAME##_##op(T x, T y) { return FN; } \
kernel void NAME( \
constant size_t &num_dims, \
constant size_t *dims, \
constant size_t *strides, \
constant size_t &el_to_sum_per_block, \
- device const T *src, \
+ device const T *src, \
device T *dst, \
uint id [[ thread_position_in_grid ]], \
uint tid [[ thread_index_in_threadgroup ]], \
uint dst_id [[ threadgroup_position_in_grid ]], \
uint block_dim [[ threads_per_threadgroup ]] \
) { \
- \
- threadgroup T shared_memory[THREADGROUP_SIZE]; \
- \
- shared_memory[tid] = START; \
- /* \
- // Elements summed in this block range from dst_id * el_to_sum_per_block \
- // to (dst_id + 1) * el_to_sum_per_block. \
- */ \
- size_t start_idx = dst_id * el_to_sum_per_block; \
- size_t stop_idx = start_idx + el_to_sum_per_block; \
- size_t idx = start_idx + tid; \
- while (idx < stop_idx) { \
- /* \
- // TODO: Fast version for the contiguous case. \
- */ \
- size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
- T x = shared_memory[tid]; \
- T y = src[strided_i]; \
- shared_memory[tid] = FN; \
- idx += block_dim; \
- } \
- \
- threadgroup_barrier(mem_flags::mem_none); \
- \
- /* \
- // reduction in shared memory \
- */ \
- for (uint s = block_dim / 2; s > 0; s >>= 1) { \
- if (tid < s) { \
- T x = shared_memory[tid]; \
- T y = shared_memory[tid + s]; \
- shared_memory[tid] = FN; \
- } \
- threadgroup_barrier(mem_flags::mem_none); \
- } \
+ threadgroup T shared_memory[THREADGROUP_SIZE]; \
+ shared_memory[tid] = START; \
+ reduce<T>(num_dims, dims, strides, el_to_sum_per_block, src, dst, id, tid, dst_id, block_dim, shared_memory, NAME##_##op); \
+} \
+
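The NAME##_##op line is what lets the expression-style macro argument survive the move: the textual FN is pasted once into a named binary helper, which the reduce<T> template then receives as an ordinary function pointer. For the instantiation at the bottom of this diff, REDUCE(x + y, fast_sum_f32_strided, float, 0) expands to roughly (reformatted, kernel signature abbreviated):

    METAL_FUNC float fast_sum_f32_strided_op(float x, float y) { return x + y; }
    kernel void fast_sum_f32_strided( /* same parameter list as the macro */ ) {
        threadgroup float shared_memory[THREADGROUP_SIZE];
        shared_memory[tid] = 0;
        reduce<float>(num_dims, dims, strides, el_to_sum_per_block, src, dst,
                      id, tid, dst_id, block_dim, shared_memory,
                      fast_sum_f32_strided_op);
    }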
+template<typename T>
+METAL_FUNC void softmax(
+ constant size_t & src_numel,
+ constant size_t & el_to_sum_per_block,
+ device const T * src,
+ device T * dst,
+ uint id,
+ uint tid,
+ uint dst_id,
+ uint block_dim,
+ threadgroup float * shared_memory
+) {
+ size_t start_idx = dst_id * el_to_sum_per_block;
+ size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel);
+ size_t idx = start_idx + tid;
+
+ float tmp = -INFINITY;
+ while (idx < stop_idx) {
+ tmp = MAX(tmp, float(src[idx]));
+ idx += block_dim;
+ }
+ shared_memory[tid] = tmp;
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ for (uint s = block_dim / 2; s > 0; s >>= 1) {
+ if (tid < s) {
+            shared_memory[tid] = MAX(shared_memory[tid], shared_memory[tid + s]);
+ }
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ }
+
+    // wait for shared_memory[0] to be filled
- dst[dst_id] = shared_memory[0]; \
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ float _max = shared_memory[0];
+
+    // prevent tid=0 from overwriting _max before other threads have written it
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ shared_memory[tid] = 0;
+
+ idx = start_idx + tid;
+ while (idx < stop_idx) {
+ const float val = exp(float(src[idx]) - _max);
+ dst[idx] = T(val);
+ shared_memory[tid] += val;
+ idx += block_dim;
+ }
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ for (uint s = block_dim / 2; s > 0; s >>= 1) {
+ if (tid < s) {
+ shared_memory[tid] += shared_memory[tid + s];
+ }
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ }
+
+ const T inv_acc = T(1.0 / shared_memory[0]);
+ idx = start_idx + tid;
+ while (idx < stop_idx) {
+ dst[idx] *= inv_acc;
+ idx += block_dim;
+ }
+}
+
+#define SOFTMAX(NAME, T) \
+kernel void NAME( \
+ constant size_t &src_numel, \
+ constant size_t &el_to_sum_per_block, \
+ device const T *src, \
+ device T *dst, \
+ uint id [[ thread_position_in_grid ]], \
+ uint tid [[ thread_index_in_threadgroup ]], \
+ uint dst_id [[ threadgroup_position_in_grid ]], \
+ uint block_dim [[ threads_per_threadgroup ]] \
+) { \
+ threadgroup float shared_memory[THREADGROUP_SIZE]; \
+ shared_memory[tid] = -INFINITY; \
+ softmax<T>(src_numel, el_to_sum_per_block, src, dst, id, tid, dst_id, block_dim, shared_memory); \
+} \
+
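softmax keeps the numerically stable three-pass form: a tree-reduced block max, an exponentiate-and-accumulate pass that also writes the unnormalized values to dst, and a final rescale by the reciprocal of the sum. Per block it computes

\[
m = \max_i x_i, \qquad \texttt{dst}_i = \frac{e^{x_i - m}}{\sum_j e^{x_j - m}}
\]

Subtracting the max keeps every exponent non-positive so exp cannot overflow, and the accumulator lives in float threadgroup memory even when T is a half-precision type.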
+template<typename T>
+METAL_FUNC void rmsnorm(
+ constant size_t & src_numel,
+ constant size_t & el_to_sum_per_block,
+ device const T * src,
+ device T * dst,
+ device const T * alpha,
+ constant float & eps,
+ uint id,
+ uint tid,
+ uint dst_id,
+ uint block_dim,
+ threadgroup float * shared_memory
+) {
+ size_t start_idx = dst_id * el_to_sum_per_block;
+ size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel);
+ size_t idx = start_idx + tid;
+
+ float tmp = 0;
+ while (idx < stop_idx) {
+ tmp = tmp + float(src[idx]) * float(src[idx]);
+ idx += block_dim;
+ }
+ shared_memory[tid] = tmp;
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ for (uint s = block_dim / 2; s > 0; s >>= 1) {
+ if (tid < s) {
+ shared_memory[tid] = shared_memory[tid] + shared_memory[tid + s];
+ }
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ }
+
+    // wait for shared_memory[0] to be filled
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ float norm = sqrt(shared_memory[0] / float(el_to_sum_per_block) + eps);
+ float inv_norm = 1.0f / norm;
+ idx = start_idx + tid;
+ while (idx < stop_idx) {
+ float val = float(src[idx]) * inv_norm;
+ if (alpha != nullptr) {
+ val *= float(alpha[idx - start_idx]);
+ }
+ dst[idx] = T(val);
+ idx += block_dim;
+ }
+}
+
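rmsnorm follows the same accumulate-in-float discipline: one tree-reduced sum of squares, then an elementwise rescale, with the alpha gain applied only when the caller passed a non-null pointer:

\[
\texttt{dst}_i = \frac{x_i}{\sqrt{\tfrac{1}{n} \sum_{j} x_j^2 + \epsilon}} \cdot \alpha_i,
\qquad n = \texttt{el\_to\_sum\_per\_block}
\]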
+#define RMSNORM(NAME, T) \
+kernel void NAME( \
+ constant size_t &src_numel, \
+ constant size_t &el_to_sum_per_block, \
+ device const T *src, \
+ device T *dst, \
+ device const T *alpha, \
+ constant float &eps, \
+ uint id [[ thread_position_in_grid ]], \
+ uint tid [[ thread_index_in_threadgroup ]], \
+ uint dst_id [[ threadgroup_position_in_grid ]], \
+ uint block_dim [[ threads_per_threadgroup ]] \
+) { \
+ threadgroup float shared_memory[THREADGROUP_SIZE]; \
+ shared_memory[tid] = 0; \
+ rmsnorm<T>(src_numel, el_to_sum_per_block, src, dst, alpha, eps, id, tid, dst_id, block_dim, shared_memory); \
} \
+template<typename T>
+METAL_FUNC void ropei(
+ constant size_t &bh,
+ constant size_t &td,
+ device const T *src,
+ device const T *cos,
+ device const T *sin,
+ device T *dst,
+ uint tid
+) {
+ if (2 * tid >= bh * td) {
+ return;
+ }
+ size_t rope_idx = tid % (td / 2);
+ T c = cos[rope_idx];
+ T s = sin[rope_idx];
+ dst[2 * tid] = src[2 * tid] * c - src[2 * tid + 1] * s;
+ dst[2 * tid + 1] = src[2 * tid] * s + src[2 * tid + 1] * c;
+}
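
ropei is the interleaved rotary-embedding variant: thread t rotates the adjacent pair (src[2t], src[2t + 1]) by the angle stored at table index t mod (td / 2):

\[
\begin{pmatrix} \texttt{dst}[2t] \\ \texttt{dst}[2t+1] \end{pmatrix}
=
\begin{pmatrix} c & -s \\ s & c \end{pmatrix}
\begin{pmatrix} \texttt{src}[2t] \\ \texttt{src}[2t+1] \end{pmatrix},
\qquad c = \cos\theta_t, \; s = \sin\theta_t
\]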
-#define SOFTMAX(NAME, T) \
-kernel void NAME( \
- constant size_t &src_numel, \
- constant size_t &el_to_sum_per_block, \
- device const T *src, \
- device T *dst, \
- \
- uint id [[ thread_position_in_grid ]], \
- uint tid [[ thread_index_in_threadgroup ]], \
- uint dst_id [[ threadgroup_position_in_grid ]], \
- uint block_dim [[ threads_per_threadgroup ]] \
-) { \
- threadgroup float shared_memory[THREADGROUP_SIZE]; \
- shared_memory[tid] = -INFINITY; \
- size_t start_idx = dst_id * el_to_sum_per_block; \
- size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel); \
- size_t idx = start_idx + tid; \
- \
- \
- float tmp = -INFINITY; \
- while (idx < stop_idx) { \
- tmp = MAX(tmp, float(src[idx])); \
- idx += block_dim; \
- } \
- shared_memory[tid] = tmp; \
- \
- threadgroup_barrier(mem_flags::mem_threadgroup); \
- \
- for (uint s = block_dim / 2; s > 0; s >>= 1) { \
- if (tid < s) { \
- shared_memory[tid] = MAX(shared_memory[tid], shared_memory[tid + s]); \
- } \
- threadgroup_barrier(mem_flags::mem_threadgroup); \
- } \
- \
- /* wait for shared_memory[0] to be filled */ \
- threadgroup_barrier(mem_flags::mem_threadgroup); \
- \
- float _max = shared_memory[0]; \
- \
- /* prevent tid=0 from overwriting _max before other threads have written it */ \
- threadgroup_barrier(mem_flags::mem_threadgroup); \
- shared_memory[tid] = 0; \
- \
- idx = start_idx + tid; \
- while (idx < stop_idx) { \
- const float val = exp(float(src[idx]) - _max); \
- dst[idx] = T(val); \
- shared_memory[tid] += val; \
- idx += block_dim; \
- } \
- threadgroup_barrier(mem_flags::mem_threadgroup); \
- for (uint s = block_dim / 2; s > 0; s >>= 1) { \
- if (tid < s) { \
- shared_memory[tid] += shared_memory[tid + s]; \
- } \
- threadgroup_barrier(mem_flags::mem_threadgroup); \
- } \
- \
- const T inv_acc = T(1.0/shared_memory[0]); \
- idx = start_idx + tid; \
- while (idx < stop_idx) { \
- dst[idx] *= inv_acc; \
- idx += block_dim; \
- } \
-} \
-
-#define RMSNORM(NAME, T) \
-kernel void NAME( \
- constant size_t &src_numel, \
- constant size_t &el_to_sum_per_block, \
- device const T *src, \
- device T *dst, \
- device const T *alpha, \
- constant float &eps, \
- \
- uint id [[ thread_position_in_grid ]], \
- uint tid [[ thread_index_in_threadgroup ]], \
- uint dst_id [[ threadgroup_position_in_grid ]], \
- uint block_dim [[ threads_per_threadgroup ]] \
-) { \
- threadgroup float shared_memory[THREADGROUP_SIZE]; \
- shared_memory[tid] = 0; \
- size_t start_idx = dst_id * el_to_sum_per_block; \
- size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel); \
- size_t idx = start_idx + tid; \
- \
- \
- float tmp = 0; \
- while (idx < stop_idx) { \
- tmp = tmp + float(src[idx]) * float(src[idx]); \
- idx += block_dim; \
- } \
- shared_memory[tid] = tmp; \
- \
- threadgroup_barrier(mem_flags::mem_threadgroup); \
- \
- for (uint s = block_dim / 2; s > 0; s >>= 1) { \
- if (tid < s) { \
- shared_memory[tid] = shared_memory[tid] + shared_memory[tid + s]; \
- } \
- threadgroup_barrier(mem_flags::mem_threadgroup); \
- } \
- \
- /* wait for shared_memory[0] to be filled */ \
- threadgroup_barrier(mem_flags::mem_threadgroup); \
- \
- float norm = sqrt(shared_memory[0] / float(el_to_sum_per_block) + eps); \
- float inv_norm = 1.0f / norm; \
- idx = start_idx + tid; \
- while (idx < stop_idx) { \
- float val = float(src[idx]) * inv_norm; \
- if (alpha != nullptr) { \
- val *= float(alpha[idx - start_idx]); \
- } \
- dst[idx] = T(val); \
- idx += block_dim; \
- } \
-} \
+template<typename T>
+METAL_FUNC void rope(
+ constant size_t &bh,
+ constant size_t &td,
+ constant size_t &d,
+ device const T *src,
+ device const T *cos,
+ device const T *sin,
+ device T *dst,
+ uint idx
+) {
+ if (2 * idx >= bh * td) {
+ return;
+ }
+ size_t i_bh = idx / (td / 2);
+ size_t i_td = idx - (td / 2) * i_bh;
+ size_t i_t = i_td / (d / 2);
+ size_t i_d = i_td - (d / 2) * i_t;
+ size_t i1 = i_bh * td + i_t * d + i_d;
+ size_t i2 = i1 + d / 2;
+ size_t i_cs = i_t * (d / 2) + i_d;
+ T c = cos[i_cs];
+ T s = sin[i_cs];
+ dst[i1] = src[i1] * c - src[i2] * s;
+ dst[i2] = src[i1] * s + src[i2] * c;
+}
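
The non-interleaved rope pairs element i_d with its partner d/2 positions later in the same head; the chain of divisions simply peels the flat thread index into (batch-head, time, dim) coordinates:

\[
\texttt{idx} = i_{bh} \cdot \tfrac{td}{2} + i_t \cdot \tfrac{d}{2} + i_d,
\qquad i_1 = i_{bh}\, td + i_t\, d + i_d, \quad i_2 = i_1 + \tfrac{d}{2}
\]

with the cos/sin tables indexed at i_t * (d/2) + i_d, i.e. one angle per (time, dim) pair shared across batches and heads.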
#define ROPEI(FN_NAME, FN_NAME_I, TYPENAME) \
kernel void FN_NAME_I( \
@@ -323,14 +450,7 @@ kernel void FN_NAME_I( \
device TYPENAME *dst, \
uint tid [[ thread_position_in_grid ]] \
) { \
- if (2 * tid >= bh * td) { \
- return; \
- } \
- size_t rope_idx = tid % (td / 2); \
- TYPENAME c = cos[rope_idx]; \
- TYPENAME s = sin[rope_idx]; \
- dst[2 * tid] = src[2 * tid] * c - src[2 * tid + 1] * s; \
- dst[2 * tid + 1] = src[2 * tid] * s + src[2 * tid + 1] * c; \
+ ropei<TYPENAME>(bh, td, src, cos, sin, dst, tid); \
}\
kernel void FN_NAME( \
constant size_t &bh, \
@@ -342,20 +462,7 @@ kernel void FN_NAME( \
device TYPENAME *dst, \
uint idx [[ thread_position_in_grid ]] \
) { \
- if (2 * idx >= bh * td) { \
- return; \
- } \
- size_t i_bh = idx / (td / 2); \
- size_t i_td = idx - (td / 2) * i_bh; \
- size_t i_t = i_td / (d / 2); \
- size_t i_d = i_td - (d / 2) * i_t; \
- size_t i1 = i_bh * td + i_t * d + i_d; \
- size_t i2 = i1 + d / 2; \
- size_t i_cs = i_t * (d / 2) + i_d; \
- TYPENAME c = cos[i_cs]; \
- TYPENAME s = sin[i_cs]; \
- dst[i1] = src[i1] * c - src[i2] * s; \
- dst[i2] = src[i1] * s + src[i2] * c; \
+ rope<TYPENAME>(bh, td, d, src, cos, sin, dst, idx); \
}\
REDUCE(x + y, fast_sum_f32_strided, float, 0)
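The remaining instantiations stay one-liners in this style; for example, a min-reduction instance following the same pattern would read (illustrative, the actual list sits outside this hunk):

    REDUCE(MIN(x, y), fast_min_f32_strided, float, HUGE_VALF)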