summaryrefslogtreecommitdiff
path: root/candle-core
diff options
context:
space:
mode:
Diffstat (limited to 'candle-core')
-rw-r--r--candle-core/src/tensor.rs16
-rw-r--r--candle-core/tests/tensor_tests.rs24
2 files changed, 34 insertions, 6 deletions
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index 82532f20..e8b02605 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -2440,9 +2440,19 @@ impl Tensor {
/// Returns log(sum(exp(tensor), dim)).
pub fn log_sum_exp<D: Dims>(&self, sum_dims: D) -> Result<Self> {
- let exp = self.exp()?;
- let sum = exp.sum(sum_dims)?;
- sum.log()
+ let sum_dims = sum_dims.to_indexes(self.shape(), "log-sum-exp")?;
+ if sum_dims.is_empty() {
+ return Ok(self.clone());
+ }
+ let max = sum_dims[1..]
+ .iter()
+ .try_fold(self.max_keepdim(sum_dims[0])?, |max, &dim| {
+ max.max_keepdim(dim)
+ })?;
+ let exp = self.broadcast_sub(&max)?.exp()?;
+ let sum = exp.sum(sum_dims.clone())?;
+
+ sum.log()? + max.squeeze_dims(&sum_dims)
}
/// Pointwise pow operation.
diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs
index cd5f4ca1..567b49f1 100644
--- a/candle-core/tests/tensor_tests.rs
+++ b/candle-core/tests/tensor_tests.rs
@@ -1326,11 +1326,29 @@ fn assert_close(a: &Tensor, b: &Tensor, epsilon: f64) -> Result<()> {
#[test]
fn log_sum_exp() -> Result<()> {
- let input = Tensor::new(&[[1f64, 2., 3.], [4., 5., 6.]], &Device::Cpu)?;
+ let input = Tensor::new(
+ &[
+ [[1f64, 2., 3.], [4., 5., 6.]],
+ [[-1000.0, -999.0, -1001.0], [1000.0, 999.0, 1001.0]],
+ ],
+ &Device::Cpu,
+ )?;
+
let output = input.log_sum_exp(D::Minus1)?;
// The expectations obtained from pytorch.
- let expected = Tensor::new(&[3.4076, 6.4076], &Device::Cpu)?;
- assert_close(&output, &expected, 0.00001)?;
+ let expected = Tensor::new(&[[3.4076, 6.4076], [-998.5924, 1001.4076]], &Device::Cpu)?;
+ assert_eq!(output.dims(), expected.dims());
+ assert_close(&output.flatten_all()?, &expected.flatten_all()?, 0.00001)?;
+
+ assert_eq!(
+ input.log_sum_exp((0, 1))?.to_vec1::<f64>()?,
+ [1000.0, 999.0, 1001.0]
+ );
+ assert_eq!(
+ input.log_sum_exp(())?.to_vec3::<f64>()?,
+ input.to_vec3::<f64>()?
+ );
+
Ok(())
}