diff options
Diffstat (limited to 'candle-nn/src')
-rw-r--r-- | candle-nn/src/ops.rs | 8 |
1 files changed, 3 insertions, 5 deletions
diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs index 55da46f8..73214077 100644 --- a/candle-nn/src/ops.rs +++ b/candle-nn/src/ops.rs @@ -103,14 +103,12 @@ impl candle::CustomOp1 for SoftmaxLastDim { .zip(dst.par_chunks_mut(dim_m1)) .for_each(|(src, dst)| { let mut max = T::neg_infinity(); - for &s in src.iter() { - max = T::max(s, max) - } - let mut sum_exp = T::zero(); + unsafe { T::vec_reduce_max(src.as_ptr(), &mut max, dim_m1) }; for (s, d) in src.iter().zip(dst.iter_mut()) { *d = (*s - max).exp(); - sum_exp += *d } + let mut sum_exp = T::zero(); + unsafe { T::vec_reduce_sum(dst.as_ptr(), &mut sum_exp, dim_m1) }; for d in dst.iter_mut() { *d /= sum_exp } |