diff options
Diffstat (limited to 'candle-transformers/src/models/bert.rs')
-rw-r--r-- | candle-transformers/src/models/bert.rs | 6 |
1 files changed, 4 insertions, 2 deletions
diff --git a/candle-transformers/src/models/bert.rs b/candle-transformers/src/models/bert.rs index d6826a16..51c524f5 100644 --- a/candle-transformers/src/models/bert.rs +++ b/candle-transformers/src/models/bert.rs @@ -7,8 +7,9 @@ pub const DTYPE: DType = DType::F32; #[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)] #[serde(rename_all = "lowercase")] -enum HiddenAct { +pub enum HiddenAct { Gelu, + GeluApproximate, Relu, } @@ -28,6 +29,7 @@ impl HiddenActLayer { match self.act { // https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/activations.py#L213 HiddenAct::Gelu => xs.gelu_erf(), + HiddenAct::GeluApproximate => xs.gelu(), HiddenAct::Relu => xs.relu(), } } @@ -48,7 +50,7 @@ pub struct Config { num_hidden_layers: usize, num_attention_heads: usize, intermediate_size: usize, - hidden_act: HiddenAct, + pub hidden_act: HiddenAct, hidden_dropout_prob: f64, max_position_embeddings: usize, type_vocab_size: usize, |