diff options
Diffstat (limited to 'candle-core/src/tensor.rs')
-rw-r--r-- | candle-core/src/tensor.rs | 34 |
1 files changed, 0 insertions, 34 deletions
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index 09f61340..8ae92c2e 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -553,40 +553,6 @@ impl Tensor { } } - /// Applies the softmax function to the input tensor, rescaling the element so that elements on - /// a slice of fixed index on dimension `dim` are between 0 and 1 and sum to 1. - /// - /// ```rust - /// use candle::{Tensor, Device}; - /// let a = Tensor::new(&[[0f32, 1., 0., 1.], [-2., 2., 3., -3.]], &Device::Cpu)?; - /// let a = a.softmax(1)?; - /// assert_eq!( - /// a.to_vec2::<f32>()?, - /// &[ - /// [0.13447072, 0.3655293, 0.13447072, 0.3655293], - /// [0.004892866, 0.26714143, 0.7261657, 0.0017999847], - /// ]); - /// # Ok::<(), candle::Error>(()) - /// ``` - pub fn softmax<D: Dim>(&self, dim: D) -> Result<Self> { - let dim = dim.to_index(self.shape(), "softmax")?; - // TODO: unify the two branches. - if self.device().is_cuda() { - // We do not have a cuda kernel for divide_by_sum_over_dim so split - // the operation. - let exp = self.exp()?; - let sum_exp = exp.sum_keepdim(dim)?; - exp.broadcast_div(&sum_exp) - } else { - let shape = self.shape(); - let mut storage = self.storage().unary_impl::<crate::op::Exp>(self.layout())?; - // The resulting storage is contiguous. - storage.divide_by_sum_over_dim(shape, dim)?; - let op = BackpropOp::new1(self, |arg| Op::Softmax(arg, dim)); - Ok(from_storage(storage, shape.clone(), op, false)) - } - } - fn squeeze_dims(self, dims: &[usize]) -> Result<Self> { match dims { [] => Ok(self), |