summaryrefslogtreecommitdiff
path: root/candle-core/tests/tensor_tests.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-core/tests/tensor_tests.rs')
-rw-r--r--candle-core/tests/tensor_tests.rs8
1 files changed, 8 insertions, 0 deletions
diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs
index 902b84f7..1e2c1c77 100644
--- a/candle-core/tests/tensor_tests.rs
+++ b/candle-core/tests/tensor_tests.rs
@@ -106,6 +106,14 @@ fn unary_op(device: &Device) -> Result<()> {
[2.6911, -0.0647, -0.1091, 1.7353, 2.7933]
]
);
+ let t_f16 = tensor.to_dtype(DType::F16)?.gelu()?.to_dtype(DType::F32)?;
+ assert_eq!(
+ test_utils::to_vec2_round(&t_f16, 2)?,
+ [
+ [-0.0, 0.84, 4.0, -0.05, 0.35],
+ [2.69, -0.07, -0.11, 1.73, 2.79]
+ ],
+ );
assert_eq!(
test_utils::to_vec2_round(&tensor.gelu_erf()?, 4)?,
[