| author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-09-26 22:30:21 +0100 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-09-26 22:30:21 +0100 |
| commit | ce0a4e3a85c40f2b46ce2ee5f58ab56c30f38d99 (patch) | |
| tree | 0010e960666b6554328b31697d764eddded45461 /candle-nn/src | |
| parent | 4abc1ea34dbc834e561f442737faf2c735f0a6ce (diff) | |
| download | candle-ce0a4e3a85c40f2b46ce2ee5f58ab56c30f38d99.tar.gz candle-ce0a4e3a85c40f2b46ce2ee5f58ab56c30f38d99.tar.bz2 candle-ce0a4e3a85c40f2b46ce2ee5f58ab56c30f38d99.zip | |
Use the gelu-erf activation. (#969)
Diffstat (limited to 'candle-nn/src')
| -rw-r--r-- | candle-nn/src/activation.rs | 4 |
|---|---|---|

1 file changed, 1 insertion, 3 deletions
```diff
diff --git a/candle-nn/src/activation.rs b/candle-nn/src/activation.rs
index 17467b31..1e67ed53 100644
--- a/candle-nn/src/activation.rs
+++ b/candle-nn/src/activation.rs
@@ -16,9 +16,7 @@ pub enum Activation {
 impl super::Module for Activation {
     fn forward(&self, xs: &Tensor) -> candle::Result<Tensor> {
         match self {
-            Self::Gelu => xs.gelu(),
-            // TODO: This is "gelu_new", not the original "gelu".
-            // There's some small numerical difference:
+            Self::Gelu => xs.gelu_erf(),
             // https://github.com/huggingface/transformers/blob/12f043eaeaabfef6f6efea411d98e6f6d3c094b7/src/transformers/activations.py#L49-L78
             Self::NewGelu => xs.gelu(),
             Self::Relu => xs.relu(),
```
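For context, a minimal sketch (not part of this commit) of what the change means numerically: `Self::Gelu` now dispatches to the exact, erf-based GELU, 0.5·x·(1 + erf(x/√2)), while `Self::NewGelu` keeps the tanh approximation, 0.5·x·(1 + tanh(√(2/π)·(x + 0.044715·x³))). The sketch below assumes `candle-core` as a dependency; the sample values and variable names are illustrative only.

```rust
// Compare candle's exact GELU (erf-based) with the tanh approximation.
// A sketch, assuming the candle-core crate; not code from this commit.
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let xs = Tensor::new(&[-2.0f32, -0.5, 0.0, 0.5, 2.0], &Device::Cpu)?;

    // Exact GELU: 0.5 * x * (1 + erf(x / sqrt(2))).
    let exact = xs.gelu_erf()?;

    // "gelu_new" tanh approximation:
    // 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))).
    let approx = xs.gelu()?;

    // The difference is small but nonzero, which is why the enum keeps
    // both variants (Gelu vs. NewGelu) rather than aliasing one to the other.
    let abs_diff = (&exact - &approx)?.abs()?;
    println!("exact  = {exact}");
    println!("approx = {approx}");
    println!("|diff| = {abs_diff}");
    Ok(())
}
```

The linked transformers source shows the same split: Hugging Face configs that request "gelu" expect the exact erf form, while "gelu_new" names the tanh approximation, which is presumably what motivates routing `Self::Gelu` to `gelu_erf` here.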