summaryrefslogtreecommitdiff
path: root/candle-core/src/op.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-core/src/op.rs')
-rw-r--r--candle-core/src/op.rs19
1 files changed, 11 insertions, 8 deletions
diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs
index 3b34eb75..776f5182 100644
--- a/candle-core/src/op.rs
+++ b/candle-core/src/op.rs
@@ -457,6 +457,13 @@ unary_op!(Recip, "recip", v, v.recip());
unary_op!(Sqr, "sqr", v, v * v, vs_sqr, vd_sqr);
unary_op!(Sqrt, "sqrt", v, v.sqrt(), vs_sqrt, vd_sqrt);
+// Hardcode the value for sqrt(2/pi)
+// https://github.com/huggingface/candle/issues/1982
+#[allow(clippy::excessive_precision)]
+const SQRT_TWO_OVER_PI_F32: f32 = 0.79788456080286535587989211986876373;
+#[allow(clippy::excessive_precision)]
+const SQRT_TWO_OVER_PI_F64: f64 = 0.79788456080286535587989211986876373;
+
/// Tanh based approximation of the `gelu` operation
/// GeluErf is the more precise one.
/// <https://en.wikipedia.org/wiki/Activation_function#Comparison_of_activation_functions>
@@ -469,7 +476,7 @@ impl UnaryOpT for Gelu {
* v
* (bf16::ONE
+ bf16::tanh(
- (bf16::from_f32_const(2.0) / bf16::PI).sqrt()
+ bf16::from_f32_const(SQRT_TWO_OVER_PI_F32)
* v
* (bf16::ONE + bf16::from_f32_const(0.044715) * v * v),
))
@@ -480,22 +487,18 @@ impl UnaryOpT for Gelu {
* v
* (f16::ONE
+ f16::tanh(
- (f16::from_f32_const(2.0) / f16::PI).sqrt()
+ f16::from_f32_const(SQRT_TWO_OVER_PI_F32)
* v
* (f16::ONE + f16::from_f32_const(0.044715) * v * v),
))
}
#[inline(always)]
fn f32(v: f32) -> f32 {
- 0.5 * v
- * (1.0
- + f32::tanh((2.0f32 / std::f32::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v)))
+ 0.5 * v * (1.0 + f32::tanh(SQRT_TWO_OVER_PI_F32 * v * (1.0 + 0.044715 * v * v)))
}
#[inline(always)]
fn f64(v: f64) -> f64 {
- 0.5 * v
- * (1.0
- + f64::tanh((2.0f64 / std::f64::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v)))
+ 0.5 * v * (1.0 + f64::tanh(SQRT_TWO_OVER_PI_F64 * v * (1.0 + 0.044715 * v * v)))
}
#[inline(always)]
fn u8(_: u8) -> u8 {