diff options
Diffstat (limited to 'candle-examples/examples/bert/main.rs')
-rw-r--r-- | candle-examples/examples/bert/main.rs | 11 |
1 files changed, 9 insertions, 2 deletions
diff --git a/candle-examples/examples/bert/main.rs b/candle-examples/examples/bert/main.rs index fcd2eab9..88e29718 100644 --- a/candle-examples/examples/bert/main.rs +++ b/candle-examples/examples/bert/main.rs @@ -3,7 +3,7 @@ extern crate intel_mkl_src; #[cfg(feature = "accelerate")] extern crate accelerate_src; -use candle_transformers::models::bert::{BertModel, Config, DTYPE}; +use candle_transformers::models::bert::{BertModel, Config, HiddenAct, DTYPE}; use anyhow::{Error as E, Result}; use candle::Tensor; @@ -45,6 +45,10 @@ struct Args { /// L2 normalization for embeddings. #[arg(long, default_value = "true")] normalize_embeddings: bool, + + /// Use tanh based approximation for Gelu instead of erf implementation. + #[arg(long, default_value = "false")] + approximate_gelu: bool, } impl Args { @@ -73,7 +77,7 @@ impl Args { (config, tokenizer, weights) }; let config = std::fs::read_to_string(config_filename)?; - let config: Config = serde_json::from_str(&config)?; + let mut config: Config = serde_json::from_str(&config)?; let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; let vb = if self.use_pth { @@ -81,6 +85,9 @@ impl Args { } else { unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)? } }; + if self.approximate_gelu { + config.hidden_act = HiddenAct::GeluApproximate; + } let model = BertModel::load(vb, &config)?; Ok((model, tokenizer)) } |