author     Thomas Santerre <thomas@santerre.xyz>     2024-04-04 02:13:12 -0400
committer  GitHub <noreply@github.com>               2024-04-04 08:13:12 +0200
commit     bd8db2a7712e14ea76a80475905db04bbf402aa6 (patch)
tree       9de047145a4f4e9a44fd13e6cf11d88a976d5178 /candle-metal-kernels
parent     318d143224805e490d396874b9e1aaf28991393c (diff)
refactor to reduce the amount of code wrapped in template syntax (#2002)
Diffstat (limited to 'candle-metal-kernels')
-rw-r--r--  candle-metal-kernels/src/reduce.metal  629
1 file changed, 368 insertions(+), 261 deletions(-)
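The whole change applies one pattern: each kernel body that previously lived inside a #define (every line ending in a \ continuation) moves into an ordinary templated METAL_FUNC, and the macro shrinks to a thin per-dtype wrapper that declares the threadgroup scratch buffer, seeds it, and calls the template. A minimal sketch of that shape, using hypothetical EXAMPLE/example_impl names that are not identifiers from this commit:

    #include <metal_stdlib>
    using namespace metal;

    constant int THREADGROUP_SIZE = 2048;

    // Body as a plain template function: readable, debuggable, shared across dtypes.
    template<typename T>
    METAL_FUNC void example_impl(
        device const T *src,
        device T *dst,
        uint tid,
        threadgroup T *shared
    ) {
        // ... the former macro body goes here, free of trailing backslashes ...
        shared[tid] = src[tid];
        dst[tid] = shared[tid];
    }

    // The macro is reduced to the per-dtype kernel entry point.
    #define EXAMPLE(NAME, T) \
    kernel void NAME( \
        device const T *src, \
        device T *dst, \
        uint tid [[ thread_index_in_threadgroup ]] \
    ) { \
        threadgroup T shared[THREADGROUP_SIZE]; \
        example_impl<T>(src, dst, tid, shared); \
    } \

    EXAMPLE(example_f32, float)

The diff below repeats this move for argmin, argmax, the generic strided reduce, softmax, rmsnorm, and the two rope kernels.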
diff --git a/candle-metal-kernels/src/reduce.metal b/candle-metal-kernels/src/reduce.metal
index be5a0921..d06efbf2 100644
--- a/candle-metal-kernels/src/reduce.metal
+++ b/candle-metal-kernels/src/reduce.metal
@@ -21,6 +21,59 @@ METAL_FUNC uint get_strided_index(
 
 constant int THREADGROUP_SIZE = 2048;
 
+template<typename T>
+METAL_FUNC void argmin(
+    constant size_t &num_dims,
+    constant size_t *dims,
+    constant size_t *strides,
+    constant size_t &el_to_sum_per_block,
+    device const T *src,
+    device uint *dst,
+    uint id,
+    uint tid,
+    uint dst_id,
+    uint block_dim,
+    threadgroup T *shared_memory,
+    threadgroup uint *shared_indices
+) {
+    bool notset = true;
+    /*
+    // Elements summed in this block range from dst_id * el_to_sum_per_block
+    // to (dst_id + 1) * el_to_sum_per_block.
+    */
+    size_t start_idx = dst_id * el_to_sum_per_block;
+    size_t stop_idx = start_idx + el_to_sum_per_block;
+    size_t idx = start_idx + tid;
+    while (idx < stop_idx) {
+        /*
+        // TODO: Fast version for the contiguous case.
+        */
+        size_t strided_i = get_strided_index(idx, num_dims, dims, strides);
+        if (notset || src[strided_i] < shared_memory[tid]) {
+            shared_memory[tid] = src[strided_i];
+            /* Assume that the reduction takes place over the last dimension which is contiguous. */
+            shared_indices[tid] = idx % dims[num_dims - 1];
+            notset = false;
+        }
+        idx += block_dim;
+    }
+
+    threadgroup_barrier(mem_flags::mem_none);
+
+    /*
+    // reduction in shared memory
+    */
+    for (uint s = block_dim / 2; s > 0; s >>= 1) {
+        if (tid < s && shared_memory[tid + s] < shared_memory[tid]) {
+            shared_indices[tid] = shared_indices[tid + s];
+            shared_memory[tid] = shared_memory[tid + s];
+        }
+        threadgroup_barrier(mem_flags::mem_none);
+    }
+
+    if (tid == 0) {
+        dst[dst_id] = shared_indices[0];
+    }
+}
+
 #define ARGMIN(NAME, T, MAXVALUE) \
 kernel void NAME( \
@@ -35,53 +88,71 @@ kernel void NAME( \
     uint dst_id [[ threadgroup_position_in_grid ]], \
     uint block_dim [[ threads_per_threadgroup ]] \
 ) { \
-    \
-    threadgroup T shared_memory[THREADGROUP_SIZE]; \
-    threadgroup uint shared_indices[THREADGROUP_SIZE]; \
-    \
-    shared_memory[tid] = MAXVALUE; \
-    shared_indices[tid] = 0xFFFFFFFF; \
-    bool notset = true; \
-    /* \
-    // Elements summed in this block range from dst_id * el_to_sum_per_block \
-    // to (dst_id + 1) * el_to_sum_per_block. \
-    */ \
-    size_t start_idx = dst_id * el_to_sum_per_block; \
-    size_t stop_idx = start_idx + el_to_sum_per_block; \
-    size_t idx = start_idx + tid; \
-    while (idx < stop_idx) { \
-        /* \
-        // TODO: Fast version for the contiguous case. \
-        */ \
-        size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
-        if (notset || src[strided_i] < shared_memory[tid]) { \
-            shared_memory[tid] = src[strided_i]; \
-            /* Assume that the reduction takes place over the last dimension which is contiguous. */ \
-            shared_indices[tid] = idx % dims[num_dims - 1]; \
-            notset = false; \
-        } \
-        idx += block_dim; \
-    } \
-    \
-    threadgroup_barrier(mem_flags::mem_none); \
-    \
-    /* \
-    // reduction in shared memory \
-    */ \
-    for (uint s = block_dim / 2; s > 0; s >>= 1) { \
-        if (tid < s && shared_memory[tid + s] < shared_memory[tid]) { \
-            shared_indices[tid] = shared_indices[tid + s]; \
-            shared_memory[tid] = shared_memory[tid + s]; \
-        } \
-        threadgroup_barrier(mem_flags::mem_none); \
-    } \
-    \
-    if (tid == 0) { \
-        dst[dst_id] = shared_indices[0]; \
-    } \
+    threadgroup T shared_memory[THREADGROUP_SIZE]; \
+    threadgroup uint shared_indices[THREADGROUP_SIZE]; \
+    shared_memory[tid] = MAXVALUE; \
+    shared_indices[tid] = 0xFFFFFFFF; \
+    argmin<T>(num_dims, dims, strides, el_to_sum_per_block, src, dst, id, tid, dst_id, block_dim, shared_memory, shared_indices); \
 } \
+
+template<typename T>
+METAL_FUNC void argmax(
+    constant size_t &num_dims,
+    constant size_t *dims,
+    constant size_t *strides,
+    constant size_t &el_to_sum_per_block,
+    device const T *src,
+    device uint *dst,
+    uint id,
+    uint tid,
+    uint dst_id,
+    uint block_dim,
+    threadgroup T *shared_memory,
+    threadgroup uint *shared_indices
+) {
+    /*
+    // Elements summed in this block range from dst_id * el_to_sum_per_block
+    // to (dst_id + 1) * el_to_sum_per_block.
+    */
+    size_t start_idx = dst_id * el_to_sum_per_block;
+    size_t stop_idx = start_idx + el_to_sum_per_block;
+    size_t idx = start_idx + tid;
+    bool notset = true;
+    while (idx < stop_idx) {
+        /*
+        // TODO: Fast version for the contiguous case.
+        */
+        size_t strided_i = get_strided_index(idx, num_dims, dims, strides);
+        if (notset || shared_memory[tid] < src[strided_i]) {
+            shared_memory[tid] = src[strided_i];
+            shared_indices[tid] = idx % dims[num_dims - 1];
+            notset = false;
+        }
+        idx += block_dim;
+    }
+
+    threadgroup_barrier(mem_flags::mem_none);
+
+    /*
+    // reduction in shared memory
+    */
+    for (uint s = block_dim / 2; s > 0; s >>= 1) {
+        if (tid < s && shared_memory[tid + s] > shared_memory[tid]) {
+            shared_indices[tid] = shared_indices[tid + s];
+            shared_memory[tid] = shared_memory[tid + s];
+        }
+        threadgroup_barrier(mem_flags::mem_none);
+    }
+
+    /*
+    // Thread 0 writes the result of the reduction
+    */
+    if (tid == 0) {
+        dst[dst_id] = shared_indices[0];
+    }
+}
+
 #define ARGMAX(NAME, T, MINVALUE) \
 kernel void NAME( \
     constant size_t &num_dims, \
@@ -95,223 +166,279 @@ kernel void NAME( \
     uint dst_id [[ threadgroup_position_in_grid ]], \
     uint block_dim [[ threads_per_threadgroup ]] \
 ) { \
-    \
     threadgroup T shared_memory[THREADGROUP_SIZE]; \
     threadgroup uint shared_indices[THREADGROUP_SIZE]; \
-    \
     shared_memory[tid] = MINVALUE; \
     shared_indices[tid] = 0xFFFFFFFF; \
-    /* \
-    // Elements summed in this block range from dst_id * el_to_sum_per_block \
-    // to (dst_id + 1) * el_to_sum_per_block. \
-    */ \
-    size_t start_idx = dst_id * el_to_sum_per_block; \
-    size_t stop_idx = start_idx + el_to_sum_per_block; \
-    size_t idx = start_idx + tid; \
-    bool notset = true; \
-    while (idx < stop_idx) { \
-        /* \
-        // TODO: Fast version for the contiguous case. \
-        */ \
-        size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
-        if (notset || shared_memory[tid] < src[strided_i]) { \
-            shared_memory[tid] = src[strided_i]; \
-            shared_indices[tid] = idx % dims[num_dims - 1]; \
-            notset = false; \
-        } \
-        idx += block_dim; \
-    } \
-    \
-    threadgroup_barrier(mem_flags::mem_none); \
-    \
-    /* \
-    // reduction in shared memory \
-    */ \
-    for (uint s = block_dim / 2; s > 0; s >>= 1) { \
-        if (tid < s && shared_memory[tid + s] > shared_memory[tid]) { \
-            shared_indices[tid] = shared_indices[tid + s]; \
-            shared_memory[tid] = shared_memory[tid + s]; \
-        } \
-        threadgroup_barrier(mem_flags::mem_none); \
-    } \
-    \
-    if (tid == 0) { \
-        dst[dst_id] = shared_indices[0]; \
-    } \
+    argmax<T>(num_dims, dims, strides, el_to_sum_per_block, src, dst, id, tid, dst_id, block_dim, shared_memory, shared_indices); \
 } \
+
+template<typename T>
+METAL_FUNC void reduce(
+    constant size_t &num_dims,
+    constant size_t *dims,
+    constant size_t *strides,
+    constant size_t &el_to_sum_per_block,
+    device const T *src,
+    device T *dst,
+    uint id,
+    uint tid,
+    uint dst_id,
+    uint block_dim,
+    threadgroup T *shared_memory,
+    T (*fn)(T, T)
+) {
+    /*
+    // Elements summed in this block range from dst_id * el_to_sum_per_block
+    // to (dst_id + 1) * el_to_sum_per_block.
+    */
+    size_t start_idx = dst_id * el_to_sum_per_block;
+    size_t stop_idx = start_idx + el_to_sum_per_block;
+    size_t idx = start_idx + tid;
+    while (idx < stop_idx) {
+        /*
+        // TODO: Fast version for the contiguous case.
+        */
+        size_t strided_i = get_strided_index(idx, num_dims, dims, strides);
+        T x = shared_memory[tid];
+        T y = src[strided_i];
+        shared_memory[tid] = fn(x, y);
+        idx += block_dim;
+    }
+
+    threadgroup_barrier(mem_flags::mem_none);
+
+    /*
+    // reduction in shared memory
+    */
+    for (uint s = block_dim / 2; s > 0; s >>= 1) {
+        if (tid < s) {
+            T x = shared_memory[tid];
+            T y = shared_memory[tid + s];
+            shared_memory[tid] = fn(x, y);
+        }
+        threadgroup_barrier(mem_flags::mem_none);
+    }
+
+    if (tid == 0) {
+        dst[dst_id] = shared_memory[0];
+    }
+}
+
 #define REDUCE(FN, NAME, T, START) \
+METAL_FUNC T NAME##_##op(T x, T y) { return FN; } \
 kernel void NAME( \
     constant size_t &num_dims, \
     constant size_t *dims, \
     constant size_t *strides, \
     constant size_t &el_to_sum_per_block, \
-    device const T *src, \
+    device const T *src, \
     device T *dst, \
     uint id [[ thread_position_in_grid ]], \
     uint tid [[ thread_index_in_threadgroup ]], \
     uint dst_id [[ threadgroup_position_in_grid ]], \
     uint block_dim [[ threads_per_threadgroup ]] \
 ) { \
-    \
-    threadgroup T shared_memory[THREADGROUP_SIZE]; \
-    \
-    shared_memory[tid] = START; \
-    /* \
-    // Elements summed in this block range from dst_id * el_to_sum_per_block \
-    // to (dst_id + 1) * el_to_sum_per_block. \
-    */ \
-    size_t start_idx = dst_id * el_to_sum_per_block; \
-    size_t stop_idx = start_idx + el_to_sum_per_block; \
-    size_t idx = start_idx + tid; \
-    while (idx < stop_idx) { \
-        /* \
-        // TODO: Fast version for the contiguous case. \
-        */ \
-        size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
-        T x = shared_memory[tid]; \
-        T y = src[strided_i]; \
-        shared_memory[tid] = FN; \
-        idx += block_dim; \
-    } \
-    \
-    threadgroup_barrier(mem_flags::mem_none); \
-    \
-    /* \
-    // reduction in shared memory \
-    */ \
-    for (uint s = block_dim / 2; s > 0; s >>= 1) { \
-        if (tid < s) { \
-            T x = shared_memory[tid]; \
-            T y = shared_memory[tid + s]; \
-            shared_memory[tid] = FN; \
-        } \
-        threadgroup_barrier(mem_flags::mem_none); \
-    } \
+    threadgroup T shared_memory[THREADGROUP_SIZE]; \
+    shared_memory[tid] = START; \
+    reduce<T>(num_dims, dims, strides, el_to_sum_per_block, src, dst, id, tid, dst_id, block_dim, shared_memory, NAME##_##op); \
+} \
+
+template<typename T>
+METAL_FUNC void softmax(
+    constant size_t &src_numel,
+    constant size_t &el_to_sum_per_block,
+    device const T *src,
+    device T *dst,
+    uint id,
+    uint tid,
+    uint dst_id,
+    uint block_dim,
+    threadgroup float *shared_memory
+) {
+    size_t start_idx = dst_id * el_to_sum_per_block;
+    size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel);
+    size_t idx = start_idx + tid;
+
+    float tmp = -INFINITY;
+    while (idx < stop_idx) {
+        tmp = MAX(tmp, float(src[idx]));
+        idx += block_dim;
+    }
+    shared_memory[tid] = tmp;
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    for (uint s = block_dim / 2; s > 0; s >>= 1) {
+        if (tid < s) {
+            shared_memory[tid] = MAX(shared_memory[tid], shared_memory[tid + s]);
+        }
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+    }
+
+    /* wait for shared_memory[0] to be filled */
-    dst[dst_id] = shared_memory[0]; \
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    float _max = shared_memory[0];
+
+    /* prevent tid=0 from overwriting _max before other threads have written it */
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    shared_memory[tid] = 0;
+
+    idx = start_idx + tid;
+    while (idx < stop_idx) {
+        const float val = exp(float(src[idx]) - _max);
+        dst[idx] = T(val);
+        shared_memory[tid] += val;
+        idx += block_dim;
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    for (uint s = block_dim / 2; s > 0; s >>= 1) {
+        if (tid < s) {
+            shared_memory[tid] += shared_memory[tid + s];
+        }
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+    }
+
+    const T inv_acc = T(1.0 / shared_memory[0]);
+    idx = start_idx + tid;
+    while (idx < stop_idx) {
+        dst[idx] *= inv_acc;
+        idx += block_dim;
+    }
+}
+
+#define SOFTMAX(NAME, T) \
+kernel void NAME( \
+    constant size_t &src_numel, \
+    constant size_t &el_to_sum_per_block, \
+    device const T *src, \
+    device T *dst, \
+    uint id [[ thread_position_in_grid ]], \
+    uint tid [[ thread_index_in_threadgroup ]], \
+    uint dst_id [[ threadgroup_position_in_grid ]], \
+    uint block_dim [[ threads_per_threadgroup ]] \
+) { \
+    threadgroup float shared_memory[THREADGROUP_SIZE]; \
+    shared_memory[tid] = -INFINITY; \
+    softmax<T>(src_numel, el_to_sum_per_block, src, dst, id, tid, dst_id, block_dim, shared_memory); \
+} \
+
+template<typename T>
+METAL_FUNC void rmsnorm(
+    constant size_t &src_numel,
+    constant size_t &el_to_sum_per_block,
+    device const T *src,
+    device T *dst,
+    device const T *alpha,
+    constant float &eps,
+    uint id,
+    uint tid,
+    uint dst_id,
+    uint block_dim,
+    threadgroup float *shared_memory
+) {
+    size_t start_idx = dst_id * el_to_sum_per_block;
+    size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel);
+    size_t idx = start_idx + tid;
+
+    float tmp = 0;
+    while (idx < stop_idx) {
+        tmp = tmp + float(src[idx]) * float(src[idx]);
+        idx += block_dim;
+    }
+    shared_memory[tid] = tmp;
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    for (uint s = block_dim / 2; s > 0; s >>= 1) {
+        if (tid < s) {
+            shared_memory[tid] = shared_memory[tid] + shared_memory[tid + s];
+        }
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+    }
+
+    /* wait for shared_memory[0] to be filled */
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    float norm = sqrt(shared_memory[0] / float(el_to_sum_per_block) + eps);
+    float inv_norm = 1.0f / norm;
+    idx = start_idx + tid;
+    while (idx < stop_idx) {
+        float val = float(src[idx]) * inv_norm;
+        if (alpha != nullptr) {
+            val *= float(alpha[idx - start_idx]);
+        }
+        dst[idx] = T(val);
+        idx += block_dim;
+    }
+}
+
+#define RMSNORM(NAME, T) \
+kernel void NAME( \
+    constant size_t &src_numel, \
+    constant size_t &el_to_sum_per_block, \
+    device const T *src, \
+    device T *dst, \
+    device const T *alpha, \
+    constant float &eps, \
+    uint id [[ thread_position_in_grid ]], \
+    uint tid [[ thread_index_in_threadgroup ]], \
+    uint dst_id [[ threadgroup_position_in_grid ]], \
+    uint block_dim [[ threads_per_threadgroup ]] \
+) { \
+    threadgroup float shared_memory[THREADGROUP_SIZE]; \
+    shared_memory[tid] = 0; \
+    rmsnorm<T>(src_numel, el_to_sum_per_block, src, dst, alpha, eps, id, tid, dst_id, block_dim, shared_memory); \
 } \
+
+template<typename T>
+METAL_FUNC void ropei(
+    constant size_t &bh,
+    constant size_t &td,
+    device const T *src,
+    device const T *cos,
+    device const T *sin,
+    device T *dst,
+    uint tid
+) {
+    if (2 * tid >= bh * td) {
+        return;
+    }
+    size_t rope_idx = tid % (td / 2);
+    T c = cos[rope_idx];
+    T s = sin[rope_idx];
+    dst[2 * tid] = src[2 * tid] * c - src[2 * tid + 1] * s;
+    dst[2 * tid + 1] = src[2 * tid] * s + src[2 * tid + 1] * c;
+}
 
-#define SOFTMAX(NAME, T) \
-kernel void NAME( \
-    constant size_t &src_numel, \
-    constant size_t &el_to_sum_per_block, \
-    device const T *src, \
-    device T *dst, \
-    \
-    uint id [[ thread_position_in_grid ]], \
-    uint tid [[ thread_index_in_threadgroup ]], \
-    uint dst_id [[ threadgroup_position_in_grid ]], \
-    uint block_dim [[ threads_per_threadgroup ]] \
-) { \
-    threadgroup float shared_memory[THREADGROUP_SIZE]; \
-    shared_memory[tid] = -INFINITY; \
-    size_t start_idx = dst_id * el_to_sum_per_block; \
-    size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel); \
-    size_t idx = start_idx + tid; \
-    \
-    \
-    float tmp = -INFINITY; \
-    while (idx < stop_idx) { \
-        tmp = MAX(tmp, float(src[idx])); \
-        idx += block_dim; \
-    } \
-    shared_memory[tid] = tmp; \
-    \
-    threadgroup_barrier(mem_flags::mem_threadgroup); \
-    \
-    for (uint s = block_dim / 2; s > 0; s >>= 1) { \
-        if (tid < s) { \
-            shared_memory[tid] = MAX(shared_memory[tid], shared_memory[tid + s]); \
-        } \
-        threadgroup_barrier(mem_flags::mem_threadgroup); \
-    } \
-    \
-    /* wait for shared_memory[0] to be filled */ \
-    threadgroup_barrier(mem_flags::mem_threadgroup); \
-    \
-    float _max = shared_memory[0]; \
-    \
-    /* prevent tid=0 from overwriting _max before other threads have written it */ \
-    threadgroup_barrier(mem_flags::mem_threadgroup); \
-    shared_memory[tid] = 0; \
-    \
-    idx = start_idx + tid; \
-    while (idx < stop_idx) { \
-        const float val = exp(float(src[idx]) - _max); \
-        dst[idx] = T(val); \
-        shared_memory[tid] += val; \
-        idx += block_dim; \
-    } \
-    threadgroup_barrier(mem_flags::mem_threadgroup); \
-    for (uint s = block_dim / 2; s > 0; s >>= 1) { \
-        if (tid < s) { \
-            shared_memory[tid] += shared_memory[tid + s]; \
-        } \
-        threadgroup_barrier(mem_flags::mem_threadgroup); \
-    } \
-    \
-    const T inv_acc = T(1.0/shared_memory[0]); \
-    idx = start_idx + tid; \
-    while (idx < stop_idx) { \
-        dst[idx] *= inv_acc; \
-        idx += block_dim; \
-    } \
-} \
-
-#define RMSNORM(NAME, T) \
-kernel void NAME( \
-    constant size_t &src_numel, \
-    constant size_t &el_to_sum_per_block, \
-    device const T *src, \
-    device T *dst, \
-    device const T *alpha, \
-    constant float &eps, \
-    \
-    uint id [[ thread_position_in_grid ]], \
-    uint tid [[ thread_index_in_threadgroup ]], \
-    uint dst_id [[ threadgroup_position_in_grid ]], \
-    uint block_dim [[ threads_per_threadgroup ]] \
-) { \
-    threadgroup float shared_memory[THREADGROUP_SIZE]; \
-    shared_memory[tid] = 0; \
-    size_t start_idx = dst_id * el_to_sum_per_block; \
-    size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel); \
-    size_t idx = start_idx + tid; \
-    \
-    \
-    float tmp = 0; \
-    while (idx < stop_idx) { \
-        tmp = tmp + float(src[idx]) * float(src[idx]); \
-        idx += block_dim; \
-    } \
-    shared_memory[tid] = tmp; \
-    \
-    threadgroup_barrier(mem_flags::mem_threadgroup); \
-    \
-    for (uint s = block_dim / 2; s > 0; s >>= 1) { \
-        if (tid < s) { \
-            shared_memory[tid] = shared_memory[tid] + shared_memory[tid + s]; \
-        } \
-        threadgroup_barrier(mem_flags::mem_threadgroup); \
-    } \
-    \
-    /* wait for shared_memory[0] to be filled */ \
-    threadgroup_barrier(mem_flags::mem_threadgroup); \
-    \
-    float norm = sqrt(shared_memory[0] / float(el_to_sum_per_block) + eps); \
-    float inv_norm = 1.0f / norm; \
-    idx = start_idx + tid; \
-    while (idx < stop_idx) { \
-        float val = float(src[idx]) * inv_norm; \
-        if (alpha != nullptr) { \
-            val *= float(alpha[idx - start_idx]); \
-        } \
-        dst[idx] = T(val); \
-        idx += block_dim; \
-    } \
-} \
+template<typename T>
+METAL_FUNC void rope(
+    constant size_t &bh,
+    constant size_t &td,
+    constant size_t &d,
+    device const T *src,
+    device const T *cos,
+    device const T *sin,
+    device T *dst,
+    uint idx
+) {
+    if (2 * idx >= bh * td) {
+        return;
+    }
+    size_t i_bh = idx / (td / 2);
+    size_t i_td = idx - (td / 2) * i_bh;
+    size_t i_t = i_td / (d / 2);
+    size_t i_d = i_td - (d / 2) * i_t;
+    size_t i1 = i_bh * td + i_t * d + i_d;
+    size_t i2 = i1 + d / 2;
+    size_t i_cs = i_t * (d / 2) + i_d;
+    T c = cos[i_cs];
+    T s = sin[i_cs];
+    dst[i1] = src[i1] * c - src[i2] * s;
+    dst[i2] = src[i1] * s + src[i2] * c;
+}
 
 #define ROPEI(FN_NAME, FN_NAME_I, TYPENAME) \
 kernel void FN_NAME_I( \
@@ -323,14 +450,7 @@ kernel void FN_NAME_I( \
     device TYPENAME *dst, \
     uint tid [[ thread_position_in_grid ]] \
 ) { \
-    if (2 * tid >= bh * td) { \
-        return; \
-    } \
-    size_t rope_idx = tid % (td / 2); \
-    TYPENAME c = cos[rope_idx]; \
-    TYPENAME s = sin[rope_idx]; \
-    dst[2 * tid] = src[2 * tid] * c - src[2 * tid + 1] * s; \
-    dst[2 * tid + 1] = src[2 * tid] * s + src[2 * tid + 1] * c; \
+    ropei<TYPENAME>(bh, td, src, cos, sin, dst, tid); \
 }\
 kernel void FN_NAME( \
     constant size_t &bh, \
@@ -342,20 +462,7 @@ kernel void FN_NAME( \
     device TYPENAME *dst, \
     uint idx [[ thread_position_in_grid ]] \
 ) { \
-    if (2 * idx >= bh * td) { \
-        return; \
-    } \
-    size_t i_bh = idx / (td / 2); \
-    size_t i_td = idx - (td / 2) * i_bh; \
-    size_t i_t = i_td / (d / 2); \
-    size_t i_d = i_td - (d / 2) * i_t; \
-    size_t i1 = i_bh * td + i_t * d + i_d; \
-    size_t i2 = i1 + d / 2; \
-    size_t i_cs = i_t * (d / 2) + i_d; \
-    TYPENAME c = cos[i_cs]; \
-    TYPENAME s = sin[i_cs]; \
-    dst[i1] = src[i1] * c - src[i2] * s; \
-    dst[i2] = src[i1] * s + src[i2] * c; \
+    rope<TYPENAME>(bh, td, d, src, cos, sin, dst, idx); \
 }\
 
 REDUCE(x + y, fast_sum_f32_strided, float, 0)
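For reference, with the new NAME##_##op indirection the surviving instantiation at the bottom, REDUCE(x + y, fast_sum_f32_strided, float, 0), should expand to roughly the following (hand-expanded here for illustration, not text from the commit):

    // What REDUCE(x + y, fast_sum_f32_strided, float, 0) produces after expansion.
    METAL_FUNC float fast_sum_f32_strided_op(float x, float y) { return x + y; }

    kernel void fast_sum_f32_strided(
        constant size_t &num_dims,
        constant size_t *dims,
        constant size_t *strides,
        constant size_t &el_to_sum_per_block,
        device const float *src,
        device float *dst,
        uint id [[ thread_position_in_grid ]],
        uint tid [[ thread_index_in_threadgroup ]],
        uint dst_id [[ threadgroup_position_in_grid ]],
        uint block_dim [[ threads_per_threadgroup ]]
    ) {
        threadgroup float shared_memory[THREADGROUP_SIZE];
        shared_memory[tid] = 0;
        // reduce<float> accumulates src into shared_memory with the op,
        // tree-reduces shared_memory, and writes shared_memory[0] to dst[dst_id].
        reduce<float>(num_dims, dims, strides, el_to_sum_per_block, src, dst,
                      id, tid, dst_id, block_dim, shared_memory,
                      fast_sum_f32_strided_op);
    }

Only the binary op and the start value differ between reductions, which is why hoisting the loop into reduce<T> removes most of the macro surface; the Metal compiler can inline the function argument, so the indirection should cost nothing at runtime.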