summaryrefslogtreecommitdiff
path: root/candle-nn/src/ops.rs
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-09-05 16:22:27 +0200
committerGitHub <noreply@github.com>2023-09-05 15:22:27 +0100
commit6615daf2425bbf33f0e1f97d2a18534e6bdb9fc3 (patch)
tree9c61a62040ee8e2e384a97ce647d810a85a1f9b1 /candle-nn/src/ops.rs
parent1c9e5394a5056aadc948f9330ea31fea4972e65e (diff)
downloadcandle-6615daf2425bbf33f0e1f97d2a18534e6bdb9fc3.tar.gz
candle-6615daf2425bbf33f0e1f97d2a18534e6bdb9fc3.tar.bz2
candle-6615daf2425bbf33f0e1f97d2a18534e6bdb9fc3.zip
Tweaks to softmax. (#745)
Diffstat (limited to 'candle-nn/src/ops.rs')
-rw-r--r--candle-nn/src/ops.rs8
1 files changed, 3 insertions, 5 deletions
diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs
index 55da46f8..73214077 100644
--- a/candle-nn/src/ops.rs
+++ b/candle-nn/src/ops.rs
@@ -103,14 +103,12 @@ impl candle::CustomOp1 for SoftmaxLastDim {
.zip(dst.par_chunks_mut(dim_m1))
.for_each(|(src, dst)| {
let mut max = T::neg_infinity();
- for &s in src.iter() {
- max = T::max(s, max)
- }
- let mut sum_exp = T::zero();
+ unsafe { T::vec_reduce_max(src.as_ptr(), &mut max, dim_m1) };
for (s, d) in src.iter().zip(dst.iter_mut()) {
*d = (*s - max).exp();
- sum_exp += *d
}
+ let mut sum_exp = T::zero();
+ unsafe { T::vec_reduce_sum(dst.as_ptr(), &mut sum_exp, dim_m1) };
for d in dst.iter_mut() {
*d /= sum_exp
}