diff options
author | Nicolas Patry <patry.nicolas@protonmail.com> | 2023-12-15 01:35:08 +0100 |
---|---|---|
committer | Nicolas Patry <patry.nicolas@protonmail.com> | 2023-12-15 01:35:08 +0100 |
commit | ece4c69a681215837fd5a008e2ee652394daa8ed (patch) | |
tree | 6bb5913a61b770f1d71df1153764058fe2d88bec /candle-metal-kernels/src/reduce.metal | |
parent | 4eeaf205d6d0577805a41dc7ae2457be1862726a (diff) | |
download | candle-ece4c69a681215837fd5a008e2ee652394daa8ed.tar.gz candle-ece4c69a681215837fd5a008e2ee652394daa8ed.tar.bz2 candle-ece4c69a681215837fd5a008e2ee652394daa8ed.zip |
Fixing softmax.
Diffstat (limited to 'candle-metal-kernels/src/reduce.metal')
-rw-r--r-- | candle-metal-kernels/src/reduce.metal | 11 |
1 file changed, 7 insertions, 4 deletions
```diff
diff --git a/candle-metal-kernels/src/reduce.metal b/candle-metal-kernels/src/reduce.metal
index 53e4664a..3633fdcf 100644
--- a/candle-metal-kernels/src/reduce.metal
+++ b/candle-metal-kernels/src/reduce.metal
@@ -67,7 +67,6 @@ kernel void NAME( \
  threadgroup_barrier(mem_flags::mem_none); \
  } \
  \
- threadgroup_barrier(mem_flags::mem_none); \
  dst[dst_id] = shared_memory[0]; \
  } \
@@ -94,11 +93,10 @@ kernel void NAME(
  size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel); \
  size_t idx = start_idx + tid; \
  \
- threadgroup_barrier(mem_flags::mem_threadgroup); \
  \
- float tmp = 0; \
+ float tmp = -INFINITY; \
  while (idx < stop_idx) { \
- tmp = MAX(tmp, src[idx]); \
+ tmp = MAX(tmp, float(src[idx])); \
  idx += block_dim; \
  } \
  shared_memory[tid] = tmp; \
@@ -109,12 +107,15 @@ kernel void NAME(
  if (tid < s) { \
  shared_memory[tid] = MAX(shared_memory[tid], shared_memory[tid + s]); \
  } \
+ threadgroup_barrier(mem_flags::mem_threadgroup); \
  } \
  \
+ /* wait for shared_memory[0] to be filled */ \
  threadgroup_barrier(mem_flags::mem_threadgroup); \
  \
  float _max = shared_memory[0]; \
  \
+ /* prevent tid=0 from overwriting _max before other threads have written it */ \
  threadgroup_barrier(mem_flags::mem_threadgroup); \
  shared_memory[tid] = 0; \
  \
@@ -125,10 +126,12 @@ kernel void NAME(
  shared_memory[tid] += val; \
  idx += block_dim; \
  } \
+ threadgroup_barrier(mem_flags::mem_threadgroup); \
  for (uint s = block_dim / 2; s > 0; s >>= 1) { \
  if (tid < s) { \
  shared_memory[tid] += shared_memory[tid + s]; \
  } \
+ threadgroup_barrier(mem_flags::mem_threadgroup); \
  } \
  \
  const T inv_acc = T(1.0/shared_memory[0]); \
```