diff options
-rw-r--r-- | candle-core/src/tensor.rs | 16 | ||||
-rw-r--r-- | candle-core/tests/tensor_tests.rs | 24 |
2 files changed, 34 insertions, 6 deletions
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index 82532f20..e8b02605 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -2440,9 +2440,19 @@ impl Tensor { /// Returns log(sum(exp(tensor), dim)). pub fn log_sum_exp<D: Dims>(&self, sum_dims: D) -> Result<Self> { - let exp = self.exp()?; - let sum = exp.sum(sum_dims)?; - sum.log() + let sum_dims = sum_dims.to_indexes(self.shape(), "log-sum-exp")?; + if sum_dims.is_empty() { + return Ok(self.clone()); + } + let max = sum_dims[1..] + .iter() + .try_fold(self.max_keepdim(sum_dims[0])?, |max, &dim| { + max.max_keepdim(dim) + })?; + let exp = self.broadcast_sub(&max)?.exp()?; + let sum = exp.sum(sum_dims.clone())?; + + sum.log()? + max.squeeze_dims(&sum_dims) } /// Pointwise pow operation. diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs index cd5f4ca1..567b49f1 100644 --- a/candle-core/tests/tensor_tests.rs +++ b/candle-core/tests/tensor_tests.rs @@ -1326,11 +1326,29 @@ fn assert_close(a: &Tensor, b: &Tensor, epsilon: f64) -> Result<()> { #[test] fn log_sum_exp() -> Result<()> { - let input = Tensor::new(&[[1f64, 2., 3.], [4., 5., 6.]], &Device::Cpu)?; + let input = Tensor::new( + &[ + [[1f64, 2., 3.], [4., 5., 6.]], + [[-1000.0, -999.0, -1001.0], [1000.0, 999.0, 1001.0]], + ], + &Device::Cpu, + )?; + let output = input.log_sum_exp(D::Minus1)?; // The expectations obtained from pytorch. - let expected = Tensor::new(&[3.4076, 6.4076], &Device::Cpu)?; - assert_close(&output, &expected, 0.00001)?; + let expected = Tensor::new(&[[3.4076, 6.4076], [-998.5924, 1001.4076]], &Device::Cpu)?; + assert_eq!(output.dims(), expected.dims()); + assert_close(&output.flatten_all()?, &expected.flatten_all()?, 0.00001)?; + + assert_eq!( + input.log_sum_exp((0, 1))?.to_vec1::<f64>()?, + [1000.0, 999.0, 1001.0] + ); + assert_eq!( + input.log_sum_exp(())?.to_vec3::<f64>()?, + input.to_vec3::<f64>()? + ); + Ok(()) } |