summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-examples/examples/bert/main.rs18
1 files changed, 11 insertions, 7 deletions
diff --git a/candle-examples/examples/bert/main.rs b/candle-examples/examples/bert/main.rs
index 9c9dc206..4de0aeac 100644
--- a/candle-examples/examples/bert/main.rs
+++ b/candle-examples/examples/bert/main.rs
@@ -621,18 +621,26 @@ struct Args {
#[arg(long)]
offline: bool,
+ /// The model to use, check out available models: https://huggingface.co/models?library=sentence-transformers&sort=trending
#[arg(long)]
model_id: Option<String>,
#[arg(long)]
revision: Option<String>,
+
+ /// The number of times to run the prompt.
+ #[arg(long, default_value = "This is an example sentence")]
+ prompt: String,
+
+ /// The number of times to run the prompt.
+ #[arg(long, default_value = "1")]
+ n: usize,
}
#[tokio::main]
async fn main() -> Result<()> {
use tokenizers::Tokenizer;
let start = std::time::Instant::now();
- println!("Building {:?}", start.elapsed());
let args = Args::parse();
let device = if args.cpu {
@@ -672,29 +680,25 @@ async fn main() -> Result<()> {
api.get(&repo, "model.safetensors").await?,
)
};
- println!("Building {:?}", start.elapsed());
let config = std::fs::read_to_string(config_filename)?;
let config: Config = serde_json::from_str(&config)?;
- println!("Config loaded {:?}", start.elapsed());
let mut tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let tokenizer = tokenizer.with_padding(None).with_truncation(None);
- println!("Tokenizer loaded {:?}", start.elapsed());
let weights = unsafe { candle::safetensors::MmapedFile::new(weights_filename)? };
let weights = weights.deserialize()?;
let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, device.clone());
let model = BertModel::load(&vb, &config)?;
- println!("Loaded {:?}", start.elapsed());
let tokens = tokenizer
- .encode("This is an example sentence", true)
+ .encode(args.prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
let token_ids = Tensor::new(&tokens[..], &device)?.unsqueeze(0)?;
let token_type_ids = token_ids.zeros_like()?;
println!("Loaded and encoded {:?}", start.elapsed());
- for _ in 0..100 {
+ for _ in 0..args.n {
let start = std::time::Instant::now();
let _ys = model.forward(&token_ids, &token_type_ids)?;
println!("Took {:?}", start.elapsed());