diff options
author | Yun-Jhong Wu <yjwuam@gmail.com> | 2024-08-01 03:37:02 -0500 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-08-01 10:37:02 +0200 |
commit | bd80078acfe25c15e3f590be4ec68ff8caddcbdd (patch) | |
tree | a631b810082fcb5e3f7cbcae0192419a1c365011 /candle-core/src/tensor.rs | |
parent | fea46cb719d5f59216f5b0a606400f1fd663190e (diff) | |
download | candle-bd80078acfe25c15e3f590be4ec68ff8caddcbdd.tar.gz candle-bd80078acfe25c15e3f590be4ec68ff8caddcbdd.tar.bz2 candle-bd80078acfe25c15e3f590be4ec68ff8caddcbdd.zip |
Fix log_sum_exp to handle large positive/negative inputs (#2367)
Diffstat (limited to 'candle-core/src/tensor.rs')
-rw-r--r-- | candle-core/src/tensor.rs | 16 |
1 files changed, 13 insertions, 3 deletions
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index 82532f20..e8b02605 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -2440,9 +2440,19 @@ impl Tensor { /// Returns log(sum(exp(tensor), dim)). pub fn log_sum_exp<D: Dims>(&self, sum_dims: D) -> Result<Self> { - let exp = self.exp()?; - let sum = exp.sum(sum_dims)?; - sum.log() + let sum_dims = sum_dims.to_indexes(self.shape(), "log-sum-exp")?; + if sum_dims.is_empty() { + return Ok(self.clone()); + } + let max = sum_dims[1..] + .iter() + .try_fold(self.max_keepdim(sum_dims[0])?, |max, &dim| { + max.max_keepdim(dim) + })?; + let exp = self.broadcast_sub(&max)?.exp()?; + let sum = exp.sum(sum_dims.clone())?; + + sum.log()? + max.squeeze_dims(&sum_dims) } /// Pointwise pow operation. |