summaryrefslogtreecommitdiff
path: root/candle-core/tests/tensor_tests.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-core/tests/tensor_tests.rs')
-rw-r--r--candle-core/tests/tensor_tests.rs13
1 files changed, 4 insertions, 9 deletions
diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs
index 1e2c1c77..b3275804 100644
--- a/candle-core/tests/tensor_tests.rs
+++ b/candle-core/tests/tensor_tests.rs
@@ -107,13 +107,8 @@ fn unary_op(device: &Device) -> Result<()> {
]
);
let t_f16 = tensor.to_dtype(DType::F16)?.gelu()?.to_dtype(DType::F32)?;
- assert_eq!(
- test_utils::to_vec2_round(&t_f16, 2)?,
- [
- [-0.0, 0.84, 4.0, -0.05, 0.35],
- [2.69, -0.07, -0.11, 1.73, 2.79]
- ],
- );
+ let max_diff = (tensor.gelu()? - t_f16)?.flatten_all()?.max(0)?;
+ assert!(max_diff.to_vec0::<f32>()? < 5e-3);
assert_eq!(
test_utils::to_vec2_round(&tensor.gelu_erf()?, 4)?,
[
@@ -1255,8 +1250,8 @@ fn pow() -> Result<()> {
let rhs = (&lhs - 2.)?;
let res = lhs.pow(&rhs)?;
assert_eq!(
- test_utils::to_vec2_round(&res, 4)?,
- [[1.0, 1.0, 3.0], [16.0, 125.0, 1296.0001]]
+ test_utils::to_vec2_round(&res, 3)?,
+ [[1.0, 1.0, 3.0], [16.0, 125.0, 1296.0]]
);
Ok(())
}