summaryrefslogtreecommitdiff
path: root/candle-nn/src/ops.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-nn/src/ops.rs')
-rw-r--r--candle-nn/src/ops.rs24
1 files changed, 24 insertions, 0 deletions
diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs
index 88196aa7..611c66d8 100644
--- a/candle-nn/src/ops.rs
+++ b/candle-nn/src/ops.rs
@@ -1,5 +1,29 @@
use candle::{Result, Tensor};
+/// Applies the softmax function to the input tensor, rescaling the element so that elements on
+/// a slice of fixed index on dimension `dim` are between 0 and 1 and sum to 1.
+///
+/// ```rust
+/// use candle::{Tensor, Device};
+/// let a = Tensor::new(&[[0f32, 1., 0., 1.], [-2., 2., 3., -3.]], &Device::Cpu)?;
+/// let a = candle_nn::ops::softmax(&a, 1)?;
+/// assert_eq!(
+/// a.to_vec2::<f32>()?,
+/// &[
+/// [0.13447072, 0.3655293, 0.13447072, 0.3655293],
+/// [0.0048928666, 0.26714146, 0.7261658, 0.0017999851]
+/// ]);
+/// # Ok::<(), candle::Error>(())
+/// ```
+pub fn softmax<D: candle::shape::Dim>(xs: &Tensor, dim: D) -> Result<Tensor> {
+ let dim = dim.to_index(xs.shape(), "softmax")?;
+ let max = xs.max_keepdim(dim)?;
+ let diff = xs.broadcast_sub(&max)?;
+ let num = diff.exp()?;
+ let den = num.sum_keepdim(dim)?;
+ num.broadcast_div(&den)
+}
+
pub fn log_softmax<D: candle::shape::Dim>(xs: &Tensor, d: D) -> Result<Tensor> {
let d = d.to_index(xs.shape(), "log-softmax")?;
let max = xs.max_keepdim(d)?;