diff options
Diffstat (limited to 'candle-core/tests/tensor_tests.rs')
-rw-r--r-- | candle-core/tests/tensor_tests.rs | 24 |
1 files changed, 21 insertions, 3 deletions
diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs index cd5f4ca1..567b49f1 100644 --- a/candle-core/tests/tensor_tests.rs +++ b/candle-core/tests/tensor_tests.rs @@ -1326,11 +1326,29 @@ fn assert_close(a: &Tensor, b: &Tensor, epsilon: f64) -> Result<()> { #[test] fn log_sum_exp() -> Result<()> { - let input = Tensor::new(&[[1f64, 2., 3.], [4., 5., 6.]], &Device::Cpu)?; + let input = Tensor::new( + &[ + [[1f64, 2., 3.], [4., 5., 6.]], + [[-1000.0, -999.0, -1001.0], [1000.0, 999.0, 1001.0]], + ], + &Device::Cpu, + )?; + let output = input.log_sum_exp(D::Minus1)?; // The expectations obtained from pytorch. - let expected = Tensor::new(&[3.4076, 6.4076], &Device::Cpu)?; - assert_close(&output, &expected, 0.00001)?; + let expected = Tensor::new(&[[3.4076, 6.4076], [-998.5924, 1001.4076]], &Device::Cpu)?; + assert_eq!(output.dims(), expected.dims()); + assert_close(&output.flatten_all()?, &expected.flatten_all()?, 0.00001)?; + + assert_eq!( + input.log_sum_exp((0, 1))?.to_vec1::<f64>()?, + [1000.0, 999.0, 1001.0] + ); + assert_eq!( + input.log_sum_exp(())?.to_vec3::<f64>()?, + input.to_vec3::<f64>()? + ); + Ok(()) } |