diff options
Diffstat (limited to 'candle-nn')
-rw-r--r-- | candle-nn/src/activation.rs | 4 |
1 files changed, 4 insertions, 0 deletions
diff --git a/candle-nn/src/activation.rs b/candle-nn/src/activation.rs index 60a7a6d1..b9745375 100644 --- a/candle-nn/src/activation.rs +++ b/candle-nn/src/activation.rs @@ -5,6 +5,7 @@ use serde::Deserialize; #[serde(rename_all = "lowercase")] pub enum Activation { #[default] + #[serde(alias = "gelu")] Gelu, #[serde(alias = "gelu_new")] NewGelu, @@ -19,6 +20,8 @@ pub enum Activation { HardSwish, Elu(f64), LeakyRelu(f64), + #[serde(alias = "gelu_pytorch_tanh")] + GeluPytorchTanh, } impl super::Module for Activation { @@ -38,6 +41,7 @@ impl super::Module for Activation { Self::HardSwish => xs * crate::ops::hard_sigmoid(xs)?, &Self::Elu(alpha) => xs.elu(alpha), &Self::LeakyRelu(negative_slope) => crate::ops::leaky_relu(xs, negative_slope), + Self::GeluPytorchTanh => xs.gelu(), } } } |