diff options
Diffstat (limited to 'candle-core/src/tensor.rs')
-rw-r--r-- | candle-core/src/tensor.rs | 16 |
1 files changed, 13 insertions, 3 deletions
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index 82532f20..e8b02605 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -2440,9 +2440,19 @@ impl Tensor { /// Returns log(sum(exp(tensor), dim)). pub fn log_sum_exp<D: Dims>(&self, sum_dims: D) -> Result<Self> { - let exp = self.exp()?; - let sum = exp.sum(sum_dims)?; - sum.log() + let sum_dims = sum_dims.to_indexes(self.shape(), "log-sum-exp")?; + if sum_dims.is_empty() { + return Ok(self.clone()); + } + let max = sum_dims[1..] + .iter() + .try_fold(self.max_keepdim(sum_dims[0])?, |max, &dim| { + max.max_keepdim(dim) + })?; + let exp = self.broadcast_sub(&max)?.exp()?; + let sum = exp.sum(sum_dims.clone())?; + + sum.log()? + max.squeeze_dims(&sum_dims) } /// Pointwise pow operation. |