summaryrefslogtreecommitdiff
path: root/candle-core/src/tensor.rs
diff options
context:
space:
mode:
authorYun-Jhong Wu <yjwuam@gmail.com>2024-08-01 03:37:02 -0500
committerGitHub <noreply@github.com>2024-08-01 10:37:02 +0200
commitbd80078acfe25c15e3f590be4ec68ff8caddcbdd (patch)
treea631b810082fcb5e3f7cbcae0192419a1c365011 /candle-core/src/tensor.rs
parentfea46cb719d5f59216f5b0a606400f1fd663190e (diff)
downloadcandle-bd80078acfe25c15e3f590be4ec68ff8caddcbdd.tar.gz
candle-bd80078acfe25c15e3f590be4ec68ff8caddcbdd.tar.bz2
candle-bd80078acfe25c15e3f590be4ec68ff8caddcbdd.zip
Fix log_sum_exp to handle large positive/negative inputs (#2367)
Diffstat (limited to 'candle-core/src/tensor.rs')
-rw-r--r--candle-core/src/tensor.rs16
1 files changed, 13 insertions, 3 deletions
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index 82532f20..e8b02605 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -2440,9 +2440,19 @@ impl Tensor {
/// Returns log(sum(exp(tensor), dim)).
pub fn log_sum_exp<D: Dims>(&self, sum_dims: D) -> Result<Self> {
- let exp = self.exp()?;
- let sum = exp.sum(sum_dims)?;
- sum.log()
+ let sum_dims = sum_dims.to_indexes(self.shape(), "log-sum-exp")?;
+ if sum_dims.is_empty() {
+ return Ok(self.clone());
+ }
+ let max = sum_dims[1..]
+ .iter()
+ .try_fold(self.max_keepdim(sum_dims[0])?, |max, &dim| {
+ max.max_keepdim(dim)
+ })?;
+ let exp = self.broadcast_sub(&max)?.exp()?;
+ let sum = exp.sum(sum_dims.clone())?;
+
+ sum.log()? + max.squeeze_dims(&sum_dims)
}
/// Pointwise pow operation.