diff options
Diffstat (limited to 'candle-metal-kernels/src/unary.metal')
-rw-r--r-- | candle-metal-kernels/src/unary.metal | 17 |
1 files changed, 15 insertions, 2 deletions
diff --git a/candle-metal-kernels/src/unary.metal b/candle-metal-kernels/src/unary.metal index ec793eae..143e9500 100644 --- a/candle-metal-kernels/src/unary.metal +++ b/candle-metal-kernels/src/unary.metal @@ -68,6 +68,8 @@ template <typename T> METAL_FUNC T silu(T in){ return in / (static_cast<T>(1) + exp(-in)); } +#define TILE_SIZE 2 + #define UNARY(FN, TYPENAME, FN_NAME, FN_NAME_STRIDED) \ kernel void FN_NAME( \ constant size_t &dim, \ @@ -79,8 +81,8 @@ kernel void FN_NAME( \ return; \ } \ output[tid] = TYPENAME(FN(float(input[tid]))); \ -}\ -kernel void FN_NAME_STRIDED( \ +} \ +kernel void FN_NAME##_##strided( \ constant size_t &dim, \ constant size_t &num_dims, \ constant size_t *dims, \ @@ -93,6 +95,17 @@ kernel void FN_NAME_STRIDED( \ return; \ } \ output[tid] = TYPENAME(FN(float(input[get_strided_index(tid, num_dims, dims, strides)]))); \ +} \ +kernel void FN_NAME##_##tiled( \ + constant size_t &dim, \ + device const TYPENAME *input, \ + device TYPENAME *output, \ + uint tid [[ thread_position_in_grid ]] \ +) { \ + for (uint i = 0; i < TILE_SIZE; i++) { \ + const uint idx = tid * TILE_SIZE + i; \ + output[idx] = TYPENAME(FN(float(input[idx]))); \ + } \ } #define UNARY_OP(NAME) \ |