summaryrefslogtreecommitdiff
path: root/candle-core/tests/tensor_tests.rs
diff options
context:
space:
mode:
authorWenqing Zong <wenqing.zong98@gmail.com>2023-12-12 16:32:17 +0000
committerGitHub <noreply@github.com>2023-12-12 10:32:17 -0600
commit77252ffb82e328322951becda5fef1e261daa9a9 (patch)
treeae9ca7a09a015fb95fa10756bff172ab8d455a76 /candle-core/tests/tensor_tests.rs
parent18eb87f25f1ff58570436bb6a9723949f816b10b (diff)
downloadcandle-77252ffb82e328322951becda5fef1e261daa9a9.tar.gz
candle-77252ffb82e328322951becda5fef1e261daa9a9.tar.bz2
candle-77252ffb82e328322951becda5fef1e261daa9a9.zip
Add logsumexp function (#1424)
Diffstat (limited to 'candle-core/tests/tensor_tests.rs')
-rw-r--r--candle-core/tests/tensor_tests.rs27
1 files changed, 26 insertions, 1 deletions
diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs
index c871dc96..95eadc24 100644
--- a/candle-core/tests/tensor_tests.rs
+++ b/candle-core/tests/tensor_tests.rs
@@ -1,4 +1,4 @@
-use candle_core::{test_device, test_utils, DType, Device, IndexOp, Result, Tensor};
+use candle_core::{test_device, test_utils, D, DType, Device, IndexOp, Result, Tensor};
fn zeros(device: &Device) -> Result<()> {
let tensor = Tensor::zeros((5, 2), DType::F32, device)?;
@@ -1221,3 +1221,28 @@ fn cumsum() -> Result<()> {
);
Ok(())
}
+
+/// A helper function for floating point comparison. Both a and b must be 1D Tensor and contains the same amount of data.
+/// Assertion passes if the difference of all pairs of a and b is smaller than epsilon.
+fn assert_close(a: &Tensor, b: &Tensor, epsilon: f64) {
+ let a_vec: Vec<f64> = a.to_vec1().unwrap();
+ let b_vec: Vec<f64> = b.to_vec1().unwrap();
+
+ assert_eq!(a_vec.len(), b_vec.len());
+ for (a, b) in a_vec.iter().zip(b_vec.iter()) {
+ assert!((a - b).abs() < epsilon);
+ }
+}
+
+#[test]
+fn logsumexp() -> Result<()> {
+ let input = Tensor::new(&[[1f32, 2., 3.], [4., 5., 6.]], &Device::Cpu)?;
+ let output = input.logsumexp(D::Minus1)?;
+
+ // Expectation get from pytorch.
+ let expected = Tensor::new(&[3.4076, 6.4076], &Device::Cpu)?;
+
+ assert_close(&output, &expected, 0.00001);
+
+ Ok(())
+} \ No newline at end of file