summaryrefslogtreecommitdiff
path: root/candle-core/src/tensor.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-core/src/tensor.rs')
-rw-r--r--candle-core/src/tensor.rs7
1 files changed, 7 insertions, 0 deletions
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index 4d9b0837..8950f2c5 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -2565,6 +2565,13 @@ impl Tensor {
}
mask.where_cond(/* on_true= */ &src, /* on_false= */ self)
}
+
+ /// Returns log(sum(exp(tensor), dim)).
+ pub fn logsumexp<D: Dims>(&self, sum_dims: D) -> Result<Self> {
+ let exp = self.exp()?;
+ let sum = exp.sum(sum_dims)?;
+ sum.log()
+ }
}
macro_rules! bin_trait {