diff options
author | Yun-Jhong Wu <yjwuam@gmail.com> | 2024-08-01 03:37:02 -0500 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-08-01 10:37:02 +0200 |
commit | bd80078acfe25c15e3f590be4ec68ff8caddcbdd (patch) | |
tree | a631b810082fcb5e3f7cbcae0192419a1c365011 | |
parent | fea46cb719d5f59216f5b0a606400f1fd663190e (diff) | |
download | candle-bd80078acfe25c15e3f590be4ec68ff8caddcbdd.tar.gz candle-bd80078acfe25c15e3f590be4ec68ff8caddcbdd.tar.bz2 candle-bd80078acfe25c15e3f590be4ec68ff8caddcbdd.zip |
Fix log_sum_exp to handle large positive/negative inputs (#2367)
-rw-r--r-- | candle-core/src/tensor.rs | 16 | ||||
-rw-r--r-- | candle-core/tests/tensor_tests.rs | 24 |
2 files changed, 34 insertions, 6 deletions
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index 82532f20..e8b02605 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -2440,9 +2440,19 @@ impl Tensor { /// Returns log(sum(exp(tensor), dim)). pub fn log_sum_exp<D: Dims>(&self, sum_dims: D) -> Result<Self> { - let exp = self.exp()?; - let sum = exp.sum(sum_dims)?; - sum.log() + let sum_dims = sum_dims.to_indexes(self.shape(), "log-sum-exp")?; + if sum_dims.is_empty() { + return Ok(self.clone()); + } + let max = sum_dims[1..] + .iter() + .try_fold(self.max_keepdim(sum_dims[0])?, |max, &dim| { + max.max_keepdim(dim) + })?; + let exp = self.broadcast_sub(&max)?.exp()?; + let sum = exp.sum(sum_dims.clone())?; + + sum.log()? + max.squeeze_dims(&sum_dims) } /// Pointwise pow operation. diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs index cd5f4ca1..567b49f1 100644 --- a/candle-core/tests/tensor_tests.rs +++ b/candle-core/tests/tensor_tests.rs @@ -1326,11 +1326,29 @@ fn assert_close(a: &Tensor, b: &Tensor, epsilon: f64) -> Result<()> { #[test] fn log_sum_exp() -> Result<()> { - let input = Tensor::new(&[[1f64, 2., 3.], [4., 5., 6.]], &Device::Cpu)?; + let input = Tensor::new( + &[ + [[1f64, 2., 3.], [4., 5., 6.]], + [[-1000.0, -999.0, -1001.0], [1000.0, 999.0, 1001.0]], + ], + &Device::Cpu, + )?; + let output = input.log_sum_exp(D::Minus1)?; // The expectations obtained from pytorch. - let expected = Tensor::new(&[3.4076, 6.4076], &Device::Cpu)?; - assert_close(&output, &expected, 0.00001)?; + let expected = Tensor::new(&[[3.4076, 6.4076], [-998.5924, 1001.4076]], &Device::Cpu)?; + assert_eq!(output.dims(), expected.dims()); + assert_close(&output.flatten_all()?, &expected.flatten_all()?, 0.00001)?; + + assert_eq!( + input.log_sum_exp((0, 1))?.to_vec1::<f64>()?, + [1000.0, 999.0, 1001.0] + ); + assert_eq!( + input.log_sum_exp(())?.to_vec3::<f64>()?, + input.to_vec3::<f64>()? + ); + Ok(()) } |