summaryrefslogtreecommitdiff
path: root/candle-examples/examples/bert/main.rs
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-07-05 13:06:42 +0100
committerGitHub <noreply@github.com>2023-07-05 13:06:42 +0100
commite4fb8c45cc7a30de4aaf365ebc1221a53a4db157 (patch)
treefee5a01b56231a6d1472fd925f76c73aa8b93ac0 /candle-examples/examples/bert/main.rs
parentbce28ab7938b27931fd51e59c8bcad37038e0337 (diff)
parent93896f6596e44285f6250f4966ada8c08fa85f09 (diff)
downloadcandle-e4fb8c45cc7a30de4aaf365ebc1221a53a4db157.tar.gz
candle-e4fb8c45cc7a30de4aaf365ebc1221a53a4db157.tar.bz2
candle-e4fb8c45cc7a30de4aaf365ebc1221a53a4db157.zip
Merge pull request #69 from LaurentMazare/upgrade_bert
Upgrading bert example to work with `bert-base-uncased`.
Diffstat (limited to 'candle-examples/examples/bert/main.rs')
-rw-r--r--candle-examples/examples/bert/main.rs121
1 files changed, 100 insertions, 21 deletions
diff --git a/candle-examples/examples/bert/main.rs b/candle-examples/examples/bert/main.rs
index e5801314..4de0aeac 100644
--- a/candle-examples/examples/bert/main.rs
+++ b/candle-examples/examples/bert/main.rs
@@ -1,10 +1,9 @@
#![allow(dead_code)]
-// The tokenizer.json and weights should be retrieved from:
-// https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2
-
-use anyhow::{Error as E, Result};
+use anyhow::{anyhow, Error as E, Result};
use candle::{safetensors::SafeTensors, DType, Device, Shape, Tensor};
+use candle_hub::{api::Api, Cache, Repo, RepoType};
use clap::Parser;
+use serde::Deserialize;
use std::collections::HashMap;
const DTYPE: DType = DType::F32;
@@ -66,7 +65,8 @@ impl<'a> VarBuilder<'a> {
}
}
-#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)]
+#[serde(rename_all = "lowercase")]
enum HiddenAct {
Gelu,
Relu,
@@ -84,13 +84,14 @@ impl HiddenAct {
}
}
-#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)]
+#[serde(rename_all = "lowercase")]
enum PositionEmbeddingType {
Absolute,
}
// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/configuration_bert.py#L1
-#[derive(Debug, Clone, PartialEq)]
+#[derive(Debug, Clone, PartialEq, Deserialize)]
struct Config {
vocab_size: usize,
hidden_size: usize,
@@ -235,8 +236,22 @@ impl LayerNorm {
}
fn load(size: usize, eps: f64, p: &str, vb: &VarBuilder) -> Result<Self> {
- let weight = vb.get(size, &format!("{p}.weight"))?;
- let bias = vb.get(size, &format!("{p}.bias"))?;
+ let (weight, bias) = match (
+ vb.get(size, &format!("{p}.weight")),
+ vb.get(size, &format!("{p}.bias")),
+ ) {
+ (Ok(weight), Ok(bias)) => (weight, bias),
+ (Err(err), _) | (_, Err(err)) => {
+ if let (Ok(weight), Ok(bias)) = (
+ vb.get(size, &format!("{p}.gamma")),
+ vb.get(size, &format!("{p}.beta")),
+ ) {
+ (weight, bias)
+ } else {
+ return Err(err.into());
+ }
+ }
+ };
Ok(Self { weight, bias, eps })
}
@@ -567,8 +582,21 @@ struct BertModel {
impl BertModel {
fn load(vb: &VarBuilder, config: &Config) -> Result<Self> {
- let embeddings = BertEmbeddings::load("embeddings", vb, config)?;
- let encoder = BertEncoder::load("encoder", vb, config)?;
+ let (embeddings, encoder) = match (
+ BertEmbeddings::load("embeddings", vb, config),
+ BertEncoder::load("encoder", vb, config),
+ ) {
+ (Ok(embeddings), Ok(encoder)) => (embeddings, encoder),
+ (Err(err), _) | (_, Err(err)) => {
+ match (
+ BertEmbeddings::load("bert.embeddings", vb, config),
+ BertEncoder::load("bert.encoder", vb, config),
+ ) {
+ (Ok(embeddings), Ok(encoder)) => (embeddings, encoder),
+ _ => return Err(err),
+ }
+ }
+ };
Ok(Self {
embeddings,
encoder,
@@ -589,15 +617,30 @@ struct Args {
#[arg(long)]
cpu: bool,
+ /// Run offline (you must have the files already cached)
+ #[arg(long)]
+ offline: bool,
+
+ /// The model to use, check out available models: https://huggingface.co/models?library=sentence-transformers&sort=trending
#[arg(long)]
- tokenizer_config: String,
+ model_id: Option<String>,
#[arg(long)]
- weights: String,
+ revision: Option<String>,
+
+ /// The number of times to run the prompt.
+ #[arg(long, default_value = "This is an example sentence")]
+ prompt: String,
+
+ /// The number of times to run the prompt.
+ #[arg(long, default_value = "1")]
+ n: usize,
}
-fn main() -> Result<()> {
+#[tokio::main]
+async fn main() -> Result<()> {
use tokenizers::Tokenizer;
+ let start = std::time::Instant::now();
let args = Args::parse();
let device = if args.cpu {
@@ -606,24 +649,60 @@ fn main() -> Result<()> {
Device::new_cuda(0)?
};
- let mut tokenizer = Tokenizer::from_file(args.tokenizer_config).map_err(E::msg)?;
+ let default_model = "sentence-transformers/all-MiniLM-L6-v2".to_string();
+ let default_revision = "refs/pr/21".to_string();
+ let (model_id, revision) = match (args.model_id, args.revision) {
+ (Some(model_id), Some(revision)) => (model_id, revision),
+ (Some(model_id), None) => (model_id, "main".to_string()),
+ (None, Some(revision)) => (default_model, revision),
+ (None, None) => (default_model, default_revision),
+ };
+
+ let repo = Repo::with_revision(model_id, RepoType::Model, revision);
+ let (config_filename, tokenizer_filename, weights_filename) = if args.offline {
+ let cache = Cache::default();
+ (
+ cache
+ .get(&repo, "config.json")
+ .ok_or(anyhow!("Missing config file in cache"))?,
+ cache
+ .get(&repo, "tokenizer.json")
+ .ok_or(anyhow!("Missing tokenizer file in cache"))?,
+ cache
+ .get(&repo, "model.safetensors")
+ .ok_or(anyhow!("Missing weights file in cache"))?,
+ )
+ } else {
+ let api = Api::new()?;
+ (
+ api.get(&repo, "config.json").await?,
+ api.get(&repo, "tokenizer.json").await?,
+ api.get(&repo, "model.safetensors").await?,
+ )
+ };
+ let config = std::fs::read_to_string(config_filename)?;
+ let config: Config = serde_json::from_str(&config)?;
+ let mut tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let tokenizer = tokenizer.with_padding(None).with_truncation(None);
- let weights = unsafe { candle::safetensors::MmapedFile::new(args.weights)? };
+ let weights = unsafe { candle::safetensors::MmapedFile::new(weights_filename)? };
let weights = weights.deserialize()?;
let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, device.clone());
- let config = Config::all_mini_lm_l6_v2();
let model = BertModel::load(&vb, &config)?;
let tokens = tokenizer
- .encode("This is an example sentence", true)
+ .encode(args.prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
let token_ids = Tensor::new(&tokens[..], &device)?.unsqueeze(0)?;
- println!("{token_ids}");
let token_type_ids = token_ids.zeros_like()?;
- let ys = model.forward(&token_ids, &token_type_ids)?;
- println!("{ys}");
+ println!("Loaded and encoded {:?}", start.elapsed());
+ for _ in 0..args.n {
+ let start = std::time::Instant::now();
+ let _ys = model.forward(&token_ids, &token_type_ids)?;
+ println!("Took {:?}", start.elapsed());
+ // println!("Ys {:?}", ys.shape());
+ }
Ok(())
}