diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-04-04 16:28:23 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-04-04 16:28:23 +0200 |
commit | 30b145150f47cc21b51e04adf03ce41995ff729f (patch) | |
tree | 73f96edd36d510024f9f6b31e9145e44dce1e213 | |
parent | f48c07e2428a6d777ffdea57a2d1ac6a7d58a8ee (diff) | |
download | candle-30b145150f47cc21b51e04adf03ce41995ff729f.tar.gz candle-30b145150f47cc21b51e04adf03ce41995ff729f.tar.bz2 candle-30b145150f47cc21b51e04adf03ce41995ff729f.zip |
Optimize the gelu f16 opt. (#2008)
* Optimize the gelu f16 opt.
* And add a test.
-rw-r--r-- | candle-core/src/op.rs | 19 | ||||
-rw-r--r-- | candle-core/tests/tensor_tests.rs | 8 |
2 files changed, 19 insertions, 8 deletions
diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs index 3b34eb75..776f5182 100644 --- a/candle-core/src/op.rs +++ b/candle-core/src/op.rs @@ -457,6 +457,13 @@ unary_op!(Recip, "recip", v, v.recip()); unary_op!(Sqr, "sqr", v, v * v, vs_sqr, vd_sqr); unary_op!(Sqrt, "sqrt", v, v.sqrt(), vs_sqrt, vd_sqrt); +// Hardcode the value for sqrt(2/pi) +// https://github.com/huggingface/candle/issues/1982 +#[allow(clippy::excessive_precision)] +const SQRT_TWO_OVER_PI_F32: f32 = 0.79788456080286535587989211986876373; +#[allow(clippy::excessive_precision)] +const SQRT_TWO_OVER_PI_F64: f64 = 0.79788456080286535587989211986876373; + /// Tanh based approximation of the `gelu` operation /// GeluErf is the more precise one. /// <https://en.wikipedia.org/wiki/Activation_function#Comparison_of_activation_functions> @@ -469,7 +476,7 @@ impl UnaryOpT for Gelu { * v * (bf16::ONE + bf16::tanh( - (bf16::from_f32_const(2.0) / bf16::PI).sqrt() + bf16::from_f32_const(SQRT_TWO_OVER_PI_F32) * v * (bf16::ONE + bf16::from_f32_const(0.044715) * v * v), )) @@ -480,22 +487,18 @@ impl UnaryOpT for Gelu { * v * (f16::ONE + f16::tanh( - (f16::from_f32_const(2.0) / f16::PI).sqrt() + f16::from_f32_const(SQRT_TWO_OVER_PI_F32) * v * (f16::ONE + f16::from_f32_const(0.044715) * v * v), )) } #[inline(always)] fn f32(v: f32) -> f32 { - 0.5 * v - * (1.0 - + f32::tanh((2.0f32 / std::f32::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v))) + 0.5 * v * (1.0 + f32::tanh(SQRT_TWO_OVER_PI_F32 * v * (1.0 + 0.044715 * v * v))) } #[inline(always)] fn f64(v: f64) -> f64 { - 0.5 * v - * (1.0 - + f64::tanh((2.0f64 / std::f64::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v))) + 0.5 * v * (1.0 + f64::tanh(SQRT_TWO_OVER_PI_F64 * v * (1.0 + 0.044715 * v * v))) } #[inline(always)] fn u8(_: u8) -> u8 { diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs index 902b84f7..1e2c1c77 100644 --- a/candle-core/tests/tensor_tests.rs +++ b/candle-core/tests/tensor_tests.rs @@ -106,6 +106,14 @@ fn unary_op(device: &Device) -> Result<()> { [2.6911, -0.0647, -0.1091, 1.7353, 2.7933] ] ); + let t_f16 = tensor.to_dtype(DType::F16)?.gelu()?.to_dtype(DType::F32)?; + assert_eq!( + test_utils::to_vec2_round(&t_f16, 2)?, + [ + [-0.0, 0.84, 4.0, -0.05, 0.35], + [2.69, -0.07, -0.11, 1.73, 2.79] + ], + ); assert_eq!( test_utils::to_vec2_round(&tensor.gelu_erf()?, 4)?, [ |