diff options
Diffstat (limited to 'candle-core/src/tensor.rs')
-rw-r--r-- | candle-core/src/tensor.rs | 7 |
1 files changed, 7 insertions, 0 deletions
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index 4d9b0837..8950f2c5 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -2565,6 +2565,13 @@ impl Tensor { } mask.where_cond(/* on_true= */ &src, /* on_false= */ self) } + + /// Returns log(sum(exp(tensor), dim)). + pub fn logsumexp<D: Dims>(&self, sum_dims: D) -> Result<Self> { + let exp = self.exp()?; + let sum = exp.sum(sum_dims)?; + sum.log() + } } macro_rules! bin_trait { |