summaryrefslogtreecommitdiff
path: root/candle-examples/examples/bert/main.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-examples/examples/bert/main.rs')
-rw-r--r--candle-examples/examples/bert/main.rs11
1 files changed, 9 insertions, 2 deletions
diff --git a/candle-examples/examples/bert/main.rs b/candle-examples/examples/bert/main.rs
index fcd2eab9..88e29718 100644
--- a/candle-examples/examples/bert/main.rs
+++ b/candle-examples/examples/bert/main.rs
@@ -3,7 +3,7 @@ extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
-use candle_transformers::models::bert::{BertModel, Config, DTYPE};
+use candle_transformers::models::bert::{BertModel, Config, HiddenAct, DTYPE};
use anyhow::{Error as E, Result};
use candle::Tensor;
@@ -45,6 +45,10 @@ struct Args {
/// L2 normalization for embeddings.
#[arg(long, default_value = "true")]
normalize_embeddings: bool,
+
+ /// Use tanh based approximation for Gelu instead of erf implementation.
+ #[arg(long, default_value = "false")]
+ approximate_gelu: bool,
}
impl Args {
@@ -73,7 +77,7 @@ impl Args {
(config, tokenizer, weights)
};
let config = std::fs::read_to_string(config_filename)?;
- let config: Config = serde_json::from_str(&config)?;
+ let mut config: Config = serde_json::from_str(&config)?;
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let vb = if self.use_pth {
@@ -81,6 +85,9 @@ impl Args {
} else {
unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)? }
};
+ if self.approximate_gelu {
+ config.hidden_act = HiddenAct::GeluApproximate;
+ }
let model = BertModel::load(vb, &config)?;
Ok((model, tokenizer))
}