summaryrefslogtreecommitdiff
path: root/candle-nn/src
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-09-26 22:30:21 +0100
committerGitHub <noreply@github.com>2023-09-26 22:30:21 +0100
commitce0a4e3a85c40f2b46ce2ee5f58ab56c30f38d99 (patch)
tree0010e960666b6554328b31697d764eddded45461 /candle-nn/src
parent4abc1ea34dbc834e561f442737faf2c735f0a6ce (diff)
downloadcandle-ce0a4e3a85c40f2b46ce2ee5f58ab56c30f38d99.tar.gz
candle-ce0a4e3a85c40f2b46ce2ee5f58ab56c30f38d99.tar.bz2
candle-ce0a4e3a85c40f2b46ce2ee5f58ab56c30f38d99.zip
Use the gelu-erf activation. (#969)
Diffstat (limited to 'candle-nn/src')
-rw-r--r--candle-nn/src/activation.rs4
1 files changed, 1 insertions, 3 deletions
diff --git a/candle-nn/src/activation.rs b/candle-nn/src/activation.rs
index 17467b31..1e67ed53 100644
--- a/candle-nn/src/activation.rs
+++ b/candle-nn/src/activation.rs
@@ -16,9 +16,7 @@ pub enum Activation {
impl super::Module for Activation {
fn forward(&self, xs: &Tensor) -> candle::Result<Tensor> {
match self {
- Self::Gelu => xs.gelu(),
- // TODO: This is "gelu_new", not the original "gelu".
- // There's some small numerical difference:
+ Self::Gelu => xs.gelu_erf(),
// https://github.com/huggingface/transformers/blob/12f043eaeaabfef6f6efea411d98e6f6d3c094b7/src/transformers/activations.py#L49-L78
Self::NewGelu => xs.gelu(),
Self::Relu => xs.relu(),