summaryrefslogtreecommitdiff
path: root/candle-core/tests
diff options
context:
space:
mode:
authorYun-Jhong Wu <yjwuam@gmail.com>2024-08-01 03:37:02 -0500
committerGitHub <noreply@github.com>2024-08-01 10:37:02 +0200
commitbd80078acfe25c15e3f590be4ec68ff8caddcbdd (patch)
treea631b810082fcb5e3f7cbcae0192419a1c365011 /candle-core/tests
parentfea46cb719d5f59216f5b0a606400f1fd663190e (diff)
downloadcandle-bd80078acfe25c15e3f590be4ec68ff8caddcbdd.tar.gz
candle-bd80078acfe25c15e3f590be4ec68ff8caddcbdd.tar.bz2
candle-bd80078acfe25c15e3f590be4ec68ff8caddcbdd.zip
Fix log_sum_exp to handle large positive/negative inputs (#2367)
Diffstat (limited to 'candle-core/tests')
-rw-r--r--candle-core/tests/tensor_tests.rs24
1 files changed, 21 insertions, 3 deletions
diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs
index cd5f4ca1..567b49f1 100644
--- a/candle-core/tests/tensor_tests.rs
+++ b/candle-core/tests/tensor_tests.rs
@@ -1326,11 +1326,29 @@ fn assert_close(a: &Tensor, b: &Tensor, epsilon: f64) -> Result<()> {
#[test]
fn log_sum_exp() -> Result<()> {
- let input = Tensor::new(&[[1f64, 2., 3.], [4., 5., 6.]], &Device::Cpu)?;
+ let input = Tensor::new(
+ &[
+ [[1f64, 2., 3.], [4., 5., 6.]],
+ [[-1000.0, -999.0, -1001.0], [1000.0, 999.0, 1001.0]],
+ ],
+ &Device::Cpu,
+ )?;
+
let output = input.log_sum_exp(D::Minus1)?;
// The expectations obtained from pytorch.
- let expected = Tensor::new(&[3.4076, 6.4076], &Device::Cpu)?;
- assert_close(&output, &expected, 0.00001)?;
+ let expected = Tensor::new(&[[3.4076, 6.4076], [-998.5924, 1001.4076]], &Device::Cpu)?;
+ assert_eq!(output.dims(), expected.dims());
+ assert_close(&output.flatten_all()?, &expected.flatten_all()?, 0.00001)?;
+
+ assert_eq!(
+ input.log_sum_exp((0, 1))?.to_vec1::<f64>()?,
+ [1000.0, 999.0, 1001.0]
+ );
+ assert_eq!(
+ input.log_sum_exp(())?.to_vec3::<f64>()?,
+ input.to_vec3::<f64>()?
+ );
+
Ok(())
}