diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-04-04 08:26:02 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-04-04 08:26:02 +0200 |
commit | 1e46cf8b1942d496f76e13c53e5bcb4cb73586a5 (patch) | |
tree | e52ff97272400981572184d6c09108f69d2a994d /candle-metal-kernels | |
parent | bd8db2a7712e14ea76a80475905db04bbf402aa6 (diff) | |
download | candle-1e46cf8b1942d496f76e13c53e5bcb4cb73586a5.tar.gz candle-1e46cf8b1942d496f76e13c53e5bcb4cb73586a5.tar.bz2 candle-1e46cf8b1942d496f76e13c53e5bcb4cb73586a5.zip |
Minor cleanups in reduce.metal. (#2004)
Diffstat (limited to 'candle-metal-kernels')
-rw-r--r-- | candle-metal-kernels/src/reduce.metal | 24 |
1 file changed, 1 insertion, 23 deletions
diff --git a/candle-metal-kernels/src/reduce.metal b/candle-metal-kernels/src/reduce.metal index d06efbf2..561d1744 100644 --- a/candle-metal-kernels/src/reduce.metal +++ b/candle-metal-kernels/src/reduce.metal @@ -37,17 +37,13 @@ METAL_FUNC void argmin( threadgroup uint *shared_indices ) { bool notset = true; - /* // Elements summed in this block range from dst_id * el_to_sum_per_block // to (dst_id + 1) * el_to_sum_per_block. - */ size_t start_idx = dst_id * el_to_sum_per_block; size_t stop_idx = start_idx + el_to_sum_per_block; size_t idx = start_idx + tid; while (idx < stop_idx) { - /* // TODO: Fast version for the contiguous case. - */ size_t strided_i = get_strided_index(idx, num_dims, dims, strides); if (notset || src[strided_i] < shared_memory[tid]) { shared_memory[tid] = src[strided_i]; @@ -59,9 +55,7 @@ METAL_FUNC void argmin( } threadgroup_barrier(mem_flags::mem_none); - /* // reduction in shared memory - */ for (uint s = block_dim / 2; s > 0; s >>= 1) { if (tid < s && shared_memory[tid + s] < shared_memory[tid]) { shared_indices[tid] = shared_indices[tid + s]; @@ -69,8 +63,7 @@ METAL_FUNC void argmin( } \ threadgroup_barrier(mem_flags::mem_none); } - - if (tid == 0){ + if (tid == 0) { dst[dst_id] = shared_indices[0]; } } @@ -111,18 +104,14 @@ METAL_FUNC void argmax( threadgroup T * shared_memory, threadgroup uint * shared_indices ) { - /* // Elements summed in this block range from dst_id * el_to_sum_per_block // to (dst_id + 1) * el_to_sum_per_block. - */ size_t start_idx = dst_id * el_to_sum_per_block; size_t stop_idx = start_idx + el_to_sum_per_block; size_t idx = start_idx + tid; bool notset = true; while (idx < stop_idx) { - /* // TODO: Fast version for the contiguous case. 
- */ size_t strided_i = get_strided_index(idx, num_dims, dims, strides); if (notset || shared_memory[tid] < src[strided_i]) { shared_memory[tid] = src[strided_i]; @@ -134,9 +123,7 @@ METAL_FUNC void argmax( threadgroup_barrier(mem_flags::mem_none); - /* // reduction in shared memory - */ for (uint s = block_dim / 2; s > 0; s >>= 1) { if (tid < s && shared_memory[tid + s] > shared_memory[tid]) { shared_indices[tid] = shared_indices[tid + s]; @@ -145,9 +132,7 @@ METAL_FUNC void argmax( threadgroup_barrier(mem_flags::mem_none); } - /* // Thread 0 writes the result of the reduction - */ if (tid == 0) { dst[dst_id] = shared_indices[0]; } @@ -188,17 +173,13 @@ METAL_FUNC void reduce( threadgroup T * shared_memory, T (*fn)(T, T) ) { - /* // Elements summed in this block range from dst_id * el_to_sum_per_block // to (dst_id + 1) * el_to_sum_per_block. - */ size_t start_idx = dst_id * el_to_sum_per_block; size_t stop_idx = start_idx + el_to_sum_per_block; size_t idx = start_idx + tid; while (idx < stop_idx) { - /* // TODO: Fast version for the contiguous case. - */ size_t strided_i = get_strided_index(idx, num_dims, dims, strides); T x = shared_memory[tid]; T y = src[strided_i]; @@ -208,9 +189,7 @@ METAL_FUNC void reduce( threadgroup_barrier(mem_flags::mem_none); - /* // reduction in shared memory - */ for (uint s = block_dim / 2; s > 0; s >>= 1) { if (tid < s) { T x = shared_memory[tid]; @@ -277,7 +256,6 @@ METAL_FUNC void softmax( } /* wait for shared_memory[0] to be filled */ - \ threadgroup_barrier(mem_flags::mem_threadgroup); float _max = shared_memory[0]; |