summaryrefslogtreecommitdiff
path: root/candle-metal-kernels
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2024-04-04 08:26:02 +0200
committerGitHub <noreply@github.com>2024-04-04 08:26:02 +0200
commit1e46cf8b1942d496f76e13c53e5bcb4cb73586a5 (patch)
treee52ff97272400981572184d6c09108f69d2a994d /candle-metal-kernels
parentbd8db2a7712e14ea76a80475905db04bbf402aa6 (diff)
downloadcandle-1e46cf8b1942d496f76e13c53e5bcb4cb73586a5.tar.gz
candle-1e46cf8b1942d496f76e13c53e5bcb4cb73586a5.tar.bz2
candle-1e46cf8b1942d496f76e13c53e5bcb4cb73586a5.zip
Minor cleanups in reduce.metal. (#2004)
Diffstat (limited to 'candle-metal-kernels')
-rw-r--r--candle-metal-kernels/src/reduce.metal24
1 files changed, 1 insertions, 23 deletions
diff --git a/candle-metal-kernels/src/reduce.metal b/candle-metal-kernels/src/reduce.metal
index d06efbf2..561d1744 100644
--- a/candle-metal-kernels/src/reduce.metal
+++ b/candle-metal-kernels/src/reduce.metal
@@ -37,17 +37,13 @@ METAL_FUNC void argmin(
threadgroup uint *shared_indices
) {
bool notset = true;
- /*
// Elements summed in this block range from dst_id * el_to_sum_per_block
// to (dst_id + 1) * el_to_sum_per_block.
- */
size_t start_idx = dst_id * el_to_sum_per_block;
size_t stop_idx = start_idx + el_to_sum_per_block;
size_t idx = start_idx + tid;
while (idx < stop_idx) {
- /*
// TODO: Fast version for the contiguous case.
- */
size_t strided_i = get_strided_index(idx, num_dims, dims, strides);
if (notset || src[strided_i] < shared_memory[tid]) {
shared_memory[tid] = src[strided_i];
@@ -59,9 +55,7 @@ METAL_FUNC void argmin(
}
threadgroup_barrier(mem_flags::mem_none);
- /*
// reduction in shared memory
- */
for (uint s = block_dim / 2; s > 0; s >>= 1) {
if (tid < s && shared_memory[tid + s] < shared_memory[tid]) {
shared_indices[tid] = shared_indices[tid + s];
@@ -69,8 +63,7 @@ METAL_FUNC void argmin(
} \
threadgroup_barrier(mem_flags::mem_none);
}
-
- if (tid == 0){
+ if (tid == 0) {
dst[dst_id] = shared_indices[0];
}
}
@@ -111,18 +104,14 @@ METAL_FUNC void argmax(
threadgroup T * shared_memory,
threadgroup uint * shared_indices
) {
- /*
// Elements summed in this block range from dst_id * el_to_sum_per_block
// to (dst_id + 1) * el_to_sum_per_block.
- */
size_t start_idx = dst_id * el_to_sum_per_block;
size_t stop_idx = start_idx + el_to_sum_per_block;
size_t idx = start_idx + tid;
bool notset = true;
while (idx < stop_idx) {
- /*
// TODO: Fast version for the contiguous case.
- */
size_t strided_i = get_strided_index(idx, num_dims, dims, strides);
if (notset || shared_memory[tid] < src[strided_i]) {
shared_memory[tid] = src[strided_i];
@@ -134,9 +123,7 @@ METAL_FUNC void argmax(
threadgroup_barrier(mem_flags::mem_none);
- /*
// reduction in shared memory
- */
for (uint s = block_dim / 2; s > 0; s >>= 1) {
if (tid < s && shared_memory[tid + s] > shared_memory[tid]) {
shared_indices[tid] = shared_indices[tid + s];
@@ -145,9 +132,7 @@ METAL_FUNC void argmax(
threadgroup_barrier(mem_flags::mem_none);
}
- /*
// Thread 0 writes the result of the reduction
- */
if (tid == 0) {
dst[dst_id] = shared_indices[0];
}
@@ -188,17 +173,13 @@ METAL_FUNC void reduce(
threadgroup T * shared_memory,
T (*fn)(T, T)
) {
- /*
// Elements summed in this block range from dst_id * el_to_sum_per_block
// to (dst_id + 1) * el_to_sum_per_block.
- */
size_t start_idx = dst_id * el_to_sum_per_block;
size_t stop_idx = start_idx + el_to_sum_per_block;
size_t idx = start_idx + tid;
while (idx < stop_idx) {
- /*
// TODO: Fast version for the contiguous case.
- */
size_t strided_i = get_strided_index(idx, num_dims, dims, strides);
T x = shared_memory[tid];
T y = src[strided_i];
@@ -208,9 +189,7 @@ METAL_FUNC void reduce(
threadgroup_barrier(mem_flags::mem_none);
- /*
// reduction in shared memory
- */
for (uint s = block_dim / 2; s > 0; s >>= 1) {
if (tid < s) {
T x = shared_memory[tid];
@@ -277,7 +256,6 @@ METAL_FUNC void softmax(
}
/* wait for shared_memory[0] to be filled */
- \
threadgroup_barrier(mem_flags::mem_threadgroup);
float _max = shared_memory[0];