summaryrefslogtreecommitdiff
path: root/candle-nn/src
diff options
context:
space:
mode:
Diffstat (limited to 'candle-nn/src')
-rw-r--r--candle-nn/src/ops.rs8
1 files changed, 3 insertions, 5 deletions
diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs
index 55da46f8..73214077 100644
--- a/candle-nn/src/ops.rs
+++ b/candle-nn/src/ops.rs
@@ -103,14 +103,12 @@ impl candle::CustomOp1 for SoftmaxLastDim {
.zip(dst.par_chunks_mut(dim_m1))
.for_each(|(src, dst)| {
let mut max = T::neg_infinity();
- for &s in src.iter() {
- max = T::max(s, max)
- }
- let mut sum_exp = T::zero();
+ unsafe { T::vec_reduce_max(src.as_ptr(), &mut max, dim_m1) };
for (s, d) in src.iter().zip(dst.iter_mut()) {
*d = (*s - max).exp();
- sum_exp += *d
}
+ let mut sum_exp = T::zero();
+ unsafe { T::vec_reduce_sum(dst.as_ptr(), &mut sum_exp, dim_m1) };
for d in dst.iter_mut() {
*d /= sum_exp
}