diff options
Diffstat (limited to 'candle-nn/src')
-rw-r--r-- | candle-nn/src/ops.rs | 24 |
1 files changed, 24 insertions, 0 deletions
diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs index 88196aa7..611c66d8 100644 --- a/candle-nn/src/ops.rs +++ b/candle-nn/src/ops.rs @@ -1,5 +1,29 @@ use candle::{Result, Tensor}; +/// Applies the softmax function to the input tensor, rescaling the element so that elements on +/// a slice of fixed index on dimension `dim` are between 0 and 1 and sum to 1. +/// +/// ```rust +/// use candle::{Tensor, Device}; +/// let a = Tensor::new(&[[0f32, 1., 0., 1.], [-2., 2., 3., -3.]], &Device::Cpu)?; +/// let a = candle_nn::ops::softmax(&a, 1)?; +/// assert_eq!( +/// a.to_vec2::<f32>()?, +/// &[ +/// [0.13447072, 0.3655293, 0.13447072, 0.3655293], +/// [0.0048928666, 0.26714146, 0.7261658, 0.0017999851] +/// ]); +/// # Ok::<(), candle::Error>(()) +/// ``` +pub fn softmax<D: candle::shape::Dim>(xs: &Tensor, dim: D) -> Result<Tensor> { + let dim = dim.to_index(xs.shape(), "softmax")?; + let max = xs.max_keepdim(dim)?; + let diff = xs.broadcast_sub(&max)?; + let num = diff.exp()?; + let den = num.sum_keepdim(dim)?; + num.broadcast_div(&den) +} + pub fn log_softmax<D: candle::shape::Dim>(xs: &Tensor, d: D) -> Result<Tensor> { let d = d.to_index(xs.shape(), "log-softmax")?; let max = xs.max_keepdim(d)?; |