diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-01-13 20:24:06 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-01-13 20:24:06 +0100 |
commit | e6d86b081980196745e5f0b0eda8ce5334c0ff67 (patch) | |
tree | f2680645ff85136d8504bde6f75e2a61cbee22f6 /candle-core/tests/tensor_tests.rs | |
parent | 88618255cb3c20b511a2f0e6db35d84081ce3c4a (diff) | |
download | candle-e6d86b081980196745e5f0b0eda8ce5334c0ff67.tar.gz candle-e6d86b081980196745e5f0b0eda8ce5334c0ff67.tar.bz2 candle-e6d86b081980196745e5f0b0eda8ce5334c0ff67.zip |
Add the pow operator. (#1583)
* Add the pow operator.
* Support the pow operation in onnx.
Diffstat (limited to 'candle-core/tests/tensor_tests.rs')
-rw-r--r-- | candle-core/tests/tensor_tests.rs | 16 |
1 files changed, 14 insertions, 2 deletions
diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs index e83fb55b..33bab1b6 100644 --- a/candle-core/tests/tensor_tests.rs +++ b/candle-core/tests/tensor_tests.rs @@ -1245,11 +1245,23 @@ fn assert_close(a: &Tensor, b: &Tensor, epsilon: f64) -> Result<()> { } #[test] -fn logsumexp() -> Result<()> { +fn log_sum_exp() -> Result<()> { let input = Tensor::new(&[[1f64, 2., 3.], [4., 5., 6.]], &Device::Cpu)?; - let output = input.logsumexp(D::Minus1)?; + let output = input.log_sum_exp(D::Minus1)?; // The expectations obtained from pytorch. let expected = Tensor::new(&[3.4076, 6.4076], &Device::Cpu)?; assert_close(&output, &expected, 0.00001)?; Ok(()) } + +#[test] +fn pow() -> Result<()> { + let lhs = Tensor::new(&[[1f32, 2., 3.], [4., 5., 6.]], &Device::Cpu)?; + let rhs = (&lhs - 2.)?; + let res = lhs.pow(&rhs)?; + assert_eq!( + test_utils::to_vec2_round(&res, 4)?, + [[1.0, 1.0, 3.0], [16.0, 125.0, 1296.0001]] + ); + Ok(()) +} |