summaryrefslogtreecommitdiff
path: root/candle-core/src/tensor.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-core/src/tensor.rs')
-rw-r--r--candle-core/src/tensor.rs16
1 files changed, 13 insertions, 3 deletions
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index 82532f20..e8b02605 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -2440,9 +2440,19 @@ impl Tensor {
/// Returns log(sum(exp(tensor), dim)).
pub fn log_sum_exp<D: Dims>(&self, sum_dims: D) -> Result<Self> {
- let exp = self.exp()?;
- let sum = exp.sum(sum_dims)?;
- sum.log()
+ let sum_dims = sum_dims.to_indexes(self.shape(), "log-sum-exp")?;
+ if sum_dims.is_empty() {
+ return Ok(self.clone());
+ }
+ let max = sum_dims[1..]
+ .iter()
+ .try_fold(self.max_keepdim(sum_dims[0])?, |max, &dim| {
+ max.max_keepdim(dim)
+ })?;
+ let exp = self.broadcast_sub(&max)?.exp()?;
+ let sum = exp.sum(sum_dims.clone())?;
+
+ sum.log()? + max.squeeze_dims(&sum_dims)
}
/// Pointwise pow operation.