Diffstat (limited to 'candle-examples/examples/bert')
 candle-examples/examples/bert/main.rs | 186
 1 file changed, 123 insertions(+), 63 deletions(-)
diff --git a/candle-examples/examples/bert/main.rs b/candle-examples/examples/bert/main.rs
index 4de0aeac..4396326d 100644
--- a/candle-examples/examples/bert/main.rs
+++ b/candle-examples/examples/bert/main.rs
@@ -5,6 +5,7 @@ use candle_hub::{api::Api, Cache, Repo, RepoType};
 use clap::Parser;
 use serde::Deserialize;
 use std::collections::HashMap;
+use tokenizers::Tokenizer;
 
 const DTYPE: DType = DType::F32;
@@ -578,6 +579,7 @@ impl BertEncoder {
 struct BertModel {
     embeddings: BertEmbeddings,
     encoder: BertEncoder,
+    device: Device,
 }
 
 impl BertModel {
@@ -600,6 +602,7 @@ impl BertModel {
         Ok(Self {
             embeddings,
             encoder,
+            device: vb.device.clone(),
         })
     }
@@ -628,81 +631,138 @@ struct Args {
     #[arg(long)]
     revision: Option<String>,
 
-    /// The number of times to run the prompt.
-    #[arg(long, default_value = "This is an example sentence")]
-    prompt: String,
+    /// When set, compute embeddings for this prompt.
+    #[arg(long)]
+    prompt: Option<String>,
 
     /// The number of times to run the prompt.
     #[arg(long, default_value = "1")]
     n: usize,
 }
 
+impl Args {
+    async fn build_model_and_tokenizer(&self) -> Result<(BertModel, Tokenizer)> {
+        let device = if self.cpu {
+            Device::Cpu
+        } else {
+            Device::new_cuda(0)?
+        };
+        let default_model = "sentence-transformers/all-MiniLM-L6-v2".to_string();
+        let default_revision = "refs/pr/21".to_string();
+        let (model_id, revision) = match (self.model_id.to_owned(), self.revision.to_owned()) {
+            (Some(model_id), Some(revision)) => (model_id, revision),
+            (Some(model_id), None) => (model_id, "main".to_string()),
+            (None, Some(revision)) => (default_model, revision),
+            (None, None) => (default_model, default_revision),
+        };
+
+        let repo = Repo::with_revision(model_id, RepoType::Model, revision);
+        let (config_filename, tokenizer_filename, weights_filename) = if self.offline {
+            let cache = Cache::default();
+            (
+                cache
+                    .get(&repo, "config.json")
+                    .ok_or(anyhow!("Missing config file in cache"))?,
+                cache
+                    .get(&repo, "tokenizer.json")
+                    .ok_or(anyhow!("Missing tokenizer file in cache"))?,
+                cache
+                    .get(&repo, "model.safetensors")
+                    .ok_or(anyhow!("Missing weights file in cache"))?,
+            )
+        } else {
+            let api = Api::new()?;
+            (
+                api.get(&repo, "config.json").await?,
+                api.get(&repo, "tokenizer.json").await?,
+                api.get(&repo, "model.safetensors").await?,
+            )
+        };
+        let config = std::fs::read_to_string(config_filename)?;
+        let config: Config = serde_json::from_str(&config)?;
+        let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
+
+        let weights = unsafe { candle::safetensors::MmapedFile::new(weights_filename)? };
+        let weights = weights.deserialize()?;
+        let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, device);
+        let model = BertModel::load(&vb, &config)?;
+        Ok((model, tokenizer))
+    }
+}
+
 #[tokio::main]
 async fn main() -> Result<()> {
-    use tokenizers::Tokenizer;
     let start = std::time::Instant::now();
     let args = Args::parse();
-    let device = if args.cpu {
-        Device::Cpu
-    } else {
-        Device::new_cuda(0)?
-    };
-
-    let default_model = "sentence-transformers/all-MiniLM-L6-v2".to_string();
-    let default_revision = "refs/pr/21".to_string();
-    let (model_id, revision) = match (args.model_id, args.revision) {
-        (Some(model_id), Some(revision)) => (model_id, revision),
-        (Some(model_id), None) => (model_id, "main".to_string()),
-        (None, Some(revision)) => (default_model, revision),
-        (None, None) => (default_model, default_revision),
-    };
-
-    let repo = Repo::with_revision(model_id, RepoType::Model, revision);
-    let (config_filename, tokenizer_filename, weights_filename) = if args.offline {
-        let cache = Cache::default();
-        (
-            cache
-                .get(&repo, "config.json")
-                .ok_or(anyhow!("Missing config file in cache"))?,
-            cache
-                .get(&repo, "tokenizer.json")
-                .ok_or(anyhow!("Missing tokenizer file in cache"))?,
-            cache
-                .get(&repo, "model.safetensors")
-                .ok_or(anyhow!("Missing weights file in cache"))?,
-        )
+    let (model, mut tokenizer) = args.build_model_and_tokenizer().await?;
+    let device = &model.device;
+
+    if let Some(prompt) = args.prompt {
+        let tokenizer = tokenizer.with_padding(None).with_truncation(None);
+        let tokens = tokenizer
+            .encode(prompt, true)
+            .map_err(E::msg)?
+            .get_ids()
+            .to_vec();
+        let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
+        let token_type_ids = token_ids.zeros_like()?;
+        println!("Loaded and encoded {:?}", start.elapsed());
+        for _ in 0..args.n {
+            let start = std::time::Instant::now();
+            let _ys = model.forward(&token_ids, &token_type_ids)?;
+            println!("Took {:?}", start.elapsed());
+        }
     } else {
-        let api = Api::new()?;
-        (
-            api.get(&repo, "config.json").await?,
-            api.get(&repo, "tokenizer.json").await?,
-            api.get(&repo, "model.safetensors").await?,
-        )
-    };
-    let config = std::fs::read_to_string(config_filename)?;
-    let config: Config = serde_json::from_str(&config)?;
-    let mut tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
-    let tokenizer = tokenizer.with_padding(None).with_truncation(None);
-
-    let weights = unsafe { candle::safetensors::MmapedFile::new(weights_filename)? };
-    let weights = weights.deserialize()?;
-    let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, device.clone());
-    let model = BertModel::load(&vb, &config)?;
-
-    let tokens = tokenizer
-        .encode(args.prompt, true)
-        .map_err(E::msg)?
-        .get_ids()
-        .to_vec();
-    let token_ids = Tensor::new(&tokens[..], &device)?.unsqueeze(0)?;
-    let token_type_ids = token_ids.zeros_like()?;
-    println!("Loaded and encoded {:?}", start.elapsed());
-    for _ in 0..args.n {
-        let start = std::time::Instant::now();
-        let _ys = model.forward(&token_ids, &token_type_ids)?;
-        println!("Took {:?}", start.elapsed());
-        // println!("Ys {:?}", ys.shape());
+        let sentences = [
+            "The cat sits outside",
+            "A man is playing guitar",
+            "I love pasta",
+            "The new movie is awesome",
+            "The cat plays in the garden",
+            "A woman watches TV",
+            "The new movie is so great",
+            "Do you like pizza?",
+        ];
+        let n_sentences = sentences.len();
+        if let Some(pp) = tokenizer.get_padding_mut() {
+            pp.strategy = tokenizers::PaddingStrategy::BatchLongest
+        }
+        let tokens = tokenizer
+            .encode_batch(sentences.to_vec(), true)
+            .map_err(E::msg)?;
+        let token_ids = tokens
+            .iter()
+            .map(|tokens| {
+                let tokens = tokens.get_ids().to_vec();
+                Ok(Tensor::new(tokens.as_slice(), device)?)
+            })
+            .collect::<Result<Vec<_>>>()?;
+        let token_ids = Tensor::stack(&token_ids, 0)?;
+        let token_type_ids = token_ids.zeros_like()?;
+        println!("running inference on batch {:?}", token_ids.shape());
+        let embeddings = model.forward(&token_ids, &token_type_ids)?;
+        println!("generated embeddings {:?}", embeddings.shape());
+        // Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
+        let (_n_sentence, n_tokens, _hidden_size) = embeddings.shape().r3()?;
+        let embeddings = (embeddings.sum(&[1])? / (n_tokens as f64))?.squeeze(1)?;
+        println!("pooled embeddings {:?}", embeddings.shape());
+        let mut similarities = vec![];
+        for i in 0..n_sentences {
+            let e_i = embeddings.get(i)?;
+            for j in (i + 1)..n_sentences {
+                let e_j = embeddings.get(j)?;
+                let sum_ij = (&e_i * &e_j)?.sum_all()?.reshape(())?.to_scalar::<f32>()?;
+                let sum_i2 = (&e_i * &e_i)?.sum_all()?.reshape(())?.to_scalar::<f32>()?;
+                let sum_j2 = (&e_j * &e_j)?.sum_all()?.reshape(())?.to_scalar::<f32>()?;
+                let cosine_similarity = sum_ij / (sum_i2 * sum_j2).sqrt();
+                similarities.push((cosine_similarity, i, j))
+            }
+        }
+        similarities.sort_by(|u, v| v.0.total_cmp(&u.0));
+        for &(score, i, j) in similarities[..5].iter() {
+            println!("score: {score:.2} '{}' '{}'", sentences[i], sentences[j])
+        }
     }
     Ok(())
 }
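
With this change the example has two modes: passing --prompt runs the forward pass args.n times on that sentence, while omitting it encodes the eight hard-coded sentences as a padded batch, mean-pools the token embeddings, and prints the five most similar pairs. For reference, here is a minimal plain-Rust sketch of the pooling and cosine-similarity math used in the batch branch; it is an illustration only, with no candle types, and the function names are hypothetical:

// Mean-pool a [n_tokens][hidden_size] matrix into one [hidden_size] vector,
// mirroring `embeddings.sum(&[1])? / n_tokens` above (padding tokens included).
fn mean_pool(token_embeddings: &[Vec<f32>]) -> Vec<f32> {
    let n_tokens = token_embeddings.len() as f32;
    let hidden_size = token_embeddings[0].len();
    let mut pooled = vec![0f32; hidden_size];
    for token in token_embeddings {
        for (p, v) in pooled.iter_mut().zip(token) {
            *p += v;
        }
    }
    for p in pooled.iter_mut() {
        *p /= n_tokens;
    }
    pooled
}

// Cosine similarity written as in the loop above: sum_ij / sqrt(sum_i2 * sum_j2).
fn cosine_similarity(e_i: &[f32], e_j: &[f32]) -> f32 {
    let sum_ij: f32 = e_i.iter().zip(e_j).map(|(a, b)| a * b).sum();
    let sum_i2: f32 = e_i.iter().map(|a| a * a).sum();
    let sum_j2: f32 = e_j.iter().map(|b| b * b).sum();
    sum_ij / (sum_i2 * sum_j2).sqrt()
}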