path: root/candle-transformers/src/models/bert.rs
Diffstat (limited to 'candle-transformers/src/models/bert.rs')
-rw-r--r--  candle-transformers/src/models/bert.rs  6
1 file changed, 4 insertions(+), 2 deletions(-)
diff --git a/candle-transformers/src/models/bert.rs b/candle-transformers/src/models/bert.rs
index d6826a16..51c524f5 100644
--- a/candle-transformers/src/models/bert.rs
+++ b/candle-transformers/src/models/bert.rs
@@ -7,8 +7,9 @@ pub const DTYPE: DType = DType::F32;
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)]
#[serde(rename_all = "lowercase")]
-enum HiddenAct {
+pub enum HiddenAct {
Gelu,
+ GeluApproximate,
Relu,
}
@@ -28,6 +29,7 @@ impl HiddenActLayer {
match self.act {
// https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/activations.py#L213
HiddenAct::Gelu => xs.gelu_erf(),
+ HiddenAct::GeluApproximate => xs.gelu(),
HiddenAct::Relu => xs.relu(),
}
}
@@ -48,7 +50,7 @@ pub struct Config {
num_hidden_layers: usize,
num_attention_heads: usize,
intermediate_size: usize,
- hidden_act: HiddenAct,
+ pub hidden_act: HiddenAct,
hidden_dropout_prob: f64,
max_position_embeddings: usize,
type_vocab_size: usize,
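
The patch makes the HiddenAct enum and the Config::hidden_act field public, and adds a GeluApproximate variant that dispatches to Tensor::gelu (the tanh approximation) rather than the erf-based Tensor::gelu_erf used for Gelu. A minimal sketch of how downstream code could now use this, assuming an owned Config value; the helper name force_approximate_gelu is hypothetical and not part of the patch:

// Hypothetical downstream usage (not part of this patch): with `HiddenAct`
// and `Config::hidden_act` public, a caller can override the activation
// before building the model. `GeluApproximate` selects the tanh-based
// approximation (`Tensor::gelu`) instead of the exact erf-based GELU.
use candle_transformers::models::bert::{Config, HiddenAct};

fn force_approximate_gelu(mut config: Config) -> Config {
    config.hidden_act = HiddenAct::GeluApproximate;
    config
}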