summary | refs | log | tree | commit | diff
path: root/candle-metal-kernels/src/reduce.metal
diff options
context:
space:
mode:
author: Nicolas Patry <patry.nicolas@protonmail.com> 2023-12-15 01:35:08 +0100
committer: Nicolas Patry <patry.nicolas@protonmail.com> 2023-12-15 01:35:08 +0100
commit: ece4c69a681215837fd5a008e2ee652394daa8ed (patch)
tree: 6bb5913a61b770f1d71df1153764058fe2d88bec /candle-metal-kernels/src/reduce.metal
parent: 4eeaf205d6d0577805a41dc7ae2457be1862726a (diff)
download: candle-ece4c69a681215837fd5a008e2ee652394daa8ed.tar.gz
candle-ece4c69a681215837fd5a008e2ee652394daa8ed.tar.bz2
candle-ece4c69a681215837fd5a008e2ee652394daa8ed.zip
Fixing softmax.
Diffstat (limited to 'candle-metal-kernels/src/reduce.metal')
-rw-r--r-- candle-metal-kernels/src/reduce.metal 11
1 files changed, 7 insertions, 4 deletions
diff --git a/candle-metal-kernels/src/reduce.metal b/candle-metal-kernels/src/reduce.metal
index 53e4664a..3633fdcf 100644
--- a/candle-metal-kernels/src/reduce.metal
+++ b/candle-metal-kernels/src/reduce.metal
@@ -67,7 +67,6 @@ kernel void NAME( \
threadgroup_barrier(mem_flags::mem_none); \
} \
\
- threadgroup_barrier(mem_flags::mem_none); \
dst[dst_id] = shared_memory[0]; \
} \
@@ -94,11 +93,10 @@ kernel void NAME(
size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel); \
size_t idx = start_idx + tid; \
\
- threadgroup_barrier(mem_flags::mem_threadgroup); \
\
- float tmp = 0; \
+ float tmp = -INFINITY; \
while (idx < stop_idx) { \
- tmp = MAX(tmp, src[idx]); \
+ tmp = MAX(tmp, float(src[idx])); \
idx += block_dim; \
} \
shared_memory[tid] = tmp; \
@@ -109,12 +107,15 @@ kernel void NAME(
if (tid < s) { \
shared_memory[tid] = MAX(shared_memory[tid], shared_memory[tid + s]); \
} \
+ threadgroup_barrier(mem_flags::mem_threadgroup); \
} \
\
+ /* wait for shared_memory[0] to be filled */ \
threadgroup_barrier(mem_flags::mem_threadgroup); \
\
float _max = shared_memory[0]; \
\
+ /* prevent tid=0 from overwriting _max before other threads have written it */ \
threadgroup_barrier(mem_flags::mem_threadgroup); \
shared_memory[tid] = 0; \
\
@@ -125,10 +126,12 @@ kernel void NAME(
shared_memory[tid] += val; \
idx += block_dim; \
} \
+ threadgroup_barrier(mem_flags::mem_threadgroup); \
for (uint s = block_dim / 2; s > 0; s >>= 1) { \
if (tid < s) { \
shared_memory[tid] += shared_memory[tid + s]; \
} \
+ threadgroup_barrier(mem_flags::mem_threadgroup); \
} \
\
const T inv_acc = T(1.0/shared_memory[0]); \