summaryrefslogtreecommitdiff
path: root/candle-core/tests
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2024-04-04 16:28:23 +0200
committerGitHub <noreply@github.com>2024-04-04 16:28:23 +0200
commit30b145150f47cc21b51e04adf03ce41995ff729f (patch)
tree73f96edd36d510024f9f6b31e9145e44dce1e213 /candle-core/tests
parentf48c07e2428a6d777ffdea57a2d1ac6a7d58a8ee (diff)
downloadcandle-30b145150f47cc21b51e04adf03ce41995ff729f.tar.gz
candle-30b145150f47cc21b51e04adf03ce41995ff729f.tar.bz2
candle-30b145150f47cc21b51e04adf03ce41995ff729f.zip
Optimize the gelu f16 opt. (#2008)
* Optimize the gelu f16 opt. * And add a test.
Diffstat (limited to 'candle-core/tests')
-rw-r--r--candle-core/tests/tensor_tests.rs8
1 files changed, 8 insertions, 0 deletions
diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs
index 902b84f7..1e2c1c77 100644
--- a/candle-core/tests/tensor_tests.rs
+++ b/candle-core/tests/tensor_tests.rs
@@ -106,6 +106,14 @@ fn unary_op(device: &Device) -> Result<()> {
[2.6911, -0.0647, -0.1091, 1.7353, 2.7933]
]
);
+ let t_f16 = tensor.to_dtype(DType::F16)?.gelu()?.to_dtype(DType::F32)?;
+ assert_eq!(
+ test_utils::to_vec2_round(&t_f16, 2)?,
+ [
+ [-0.0, 0.84, 4.0, -0.05, 0.35],
+ [2.69, -0.07, -0.11, 1.73, 2.79]
+ ],
+ );
assert_eq!(
test_utils::to_vec2_round(&tensor.gelu_erf()?, 4)?,
[