diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-09-05 16:22:27 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-09-05 15:22:27 +0100 |
commit | 6615daf2425bbf33f0e1f97d2a18534e6bdb9fc3 (patch) | |
tree | 9c61a62040ee8e2e384a97ce647d810a85a1f9b1 /candle-nn/src/ops.rs | |
parent | 1c9e5394a5056aadc948f9330ea31fea4972e65e (diff) | |
download | candle-6615daf2425bbf33f0e1f97d2a18534e6bdb9fc3.tar.gz candle-6615daf2425bbf33f0e1f97d2a18534e6bdb9fc3.tar.bz2 candle-6615daf2425bbf33f0e1f97d2a18534e6bdb9fc3.zip |
Tweaks to softmax. (#745)
Diffstat (limited to 'candle-nn/src/ops.rs')
-rw-r--r-- | candle-nn/src/ops.rs | 8 |
1 files changed, 3 insertions, 5 deletions
diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs index 55da46f8..73214077 100644 --- a/candle-nn/src/ops.rs +++ b/candle-nn/src/ops.rs @@ -103,14 +103,12 @@ impl candle::CustomOp1 for SoftmaxLastDim { .zip(dst.par_chunks_mut(dim_m1)) .for_each(|(src, dst)| { let mut max = T::neg_infinity(); - for &s in src.iter() { - max = T::max(s, max) - } - let mut sum_exp = T::zero(); + unsafe { T::vec_reduce_max(src.as_ptr(), &mut max, dim_m1) }; for (s, d) in src.iter().zip(dst.iter_mut()) { *d = (*s - max).exp(); - sum_exp += *d } + let mut sum_exp = T::zero(); + unsafe { T::vec_reduce_sum(dst.as_ptr(), &mut sum_exp, dim_m1) }; for d in dst.iter_mut() { *d /= sum_exp } |