diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-12-12 10:56:11 -0600 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-12-12 10:56:11 -0600 |
commit | 4cb443d00a83d3de994b6c799d501f2ffc2b034b (patch) | |
tree | d2853f9c5bf9d944f425965132fb6bf856186be6 /candle-core | |
parent | 77252ffb82e328322951becda5fef1e261daa9a9 (diff) | |
download | candle-4cb443d00a83d3de994b6c799d501f2ffc2b034b.tar.gz candle-4cb443d00a83d3de994b6c799d501f2ffc2b034b.tar.bz2 candle-4cb443d00a83d3de994b6c799d501f2ffc2b034b.zip |
Fix the logsumexp test. (#1426)
Diffstat (limited to 'candle-core')
-rw-r--r-- | candle-core/tests/tensor_tests.rs | 20 |
1 files changed, 9 insertions, 11 deletions
diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs index 95eadc24..a4548d56 100644 --- a/candle-core/tests/tensor_tests.rs +++ b/candle-core/tests/tensor_tests.rs @@ -1,4 +1,4 @@ -use candle_core::{test_device, test_utils, D, DType, Device, IndexOp, Result, Tensor}; +use candle_core::{test_device, test_utils, DType, Device, IndexOp, Result, Tensor, D}; fn zeros(device: &Device) -> Result<()> { let tensor = Tensor::zeros((5, 2), DType::F32, device)?; @@ -1224,25 +1224,23 @@ fn cumsum() -> Result<()> { /// A helper function for floating point comparison. Both a and b must be 1D Tensor and contains the same amount of data. /// Assertion passes if the difference of all pairs of a and b is smaller than epsilon. -fn assert_close(a: &Tensor, b: &Tensor, epsilon: f64) { - let a_vec: Vec<f64> = a.to_vec1().unwrap(); - let b_vec: Vec<f64> = b.to_vec1().unwrap(); +fn assert_close(a: &Tensor, b: &Tensor, epsilon: f64) -> Result<()> { + let a_vec: Vec<f64> = a.to_vec1()?; + let b_vec: Vec<f64> = b.to_vec1()?; assert_eq!(a_vec.len(), b_vec.len()); for (a, b) in a_vec.iter().zip(b_vec.iter()) { assert!((a - b).abs() < epsilon); } + Ok(()) } #[test] fn logsumexp() -> Result<()> { - let input = Tensor::new(&[[1f32, 2., 3.], [4., 5., 6.]], &Device::Cpu)?; + let input = Tensor::new(&[[1f64, 2., 3.], [4., 5., 6.]], &Device::Cpu)?; let output = input.logsumexp(D::Minus1)?; - - // Expectation get from pytorch. + // The expectations obtained from pytorch. let expected = Tensor::new(&[3.4076, 6.4076], &Device::Cpu)?; - - assert_close(&output, &expected, 0.00001); - + assert_close(&output, &expected, 0.00001)?; Ok(()) -}
\ No newline at end of file +} |