summaryrefslogtreecommitdiff
path: root/candle-metal-kernels/src/unary.metal
diff options
context:
space:
mode:
Diffstat (limited to 'candle-metal-kernels/src/unary.metal')
-rw-r--r--candle-metal-kernels/src/unary.metal17
1 files changed, 15 insertions, 2 deletions
diff --git a/candle-metal-kernels/src/unary.metal b/candle-metal-kernels/src/unary.metal
index ec793eae..143e9500 100644
--- a/candle-metal-kernels/src/unary.metal
+++ b/candle-metal-kernels/src/unary.metal
@@ -68,6 +68,8 @@ template <typename T> METAL_FUNC T silu(T in){
return in / (static_cast<T>(1) + exp(-in));
}
+#define TILE_SIZE 2
+
#define UNARY(FN, TYPENAME, FN_NAME, FN_NAME_STRIDED) \
kernel void FN_NAME( \
constant size_t &dim, \
@@ -79,8 +81,8 @@ kernel void FN_NAME( \
return; \
} \
output[tid] = TYPENAME(FN(float(input[tid]))); \
-}\
-kernel void FN_NAME_STRIDED( \
+} \
+kernel void FN_NAME##_##strided( \
constant size_t &dim, \
constant size_t &num_dims, \
constant size_t *dims, \
@@ -93,6 +95,17 @@ kernel void FN_NAME_STRIDED( \
return; \
} \
output[tid] = TYPENAME(FN(float(input[get_strided_index(tid, num_dims, dims, strides)]))); \
+} \
+kernel void FN_NAME##_##tiled( \
+ constant size_t &dim, \
+ device const TYPENAME *input, \
+ device TYPENAME *output, \
+ uint tid [[ thread_position_in_grid ]] \
+) { \
+ for (uint i = 0; i < TILE_SIZE; i++) { \
+ const uint idx = tid * TILE_SIZE + i; \
+ output[idx] = TYPENAME(FN(float(input[idx]))); \
+ } \
}
#define UNARY_OP(NAME) \