#[cfg(feature = "mkl")] extern crate intel_mkl_src; #[cfg(feature = "accelerate")] extern crate accelerate_src; use candle_transformers::models::jina_bert::{BertModel, Config, PositionEmbeddingType}; use anyhow::Error as E; use candle::{DType, Module, Tensor}; use candle_nn::VarBuilder; use clap::Parser; #[derive(Parser, Debug)] #[command(author, version, about, long_about = None)] struct Args { /// Run on CPU rather than on GPU. #[arg(long)] cpu: bool, /// Enable tracing (generates a trace-timestamp.json file). #[arg(long)] tracing: bool, /// When set, compute embeddings for this prompt. #[arg(long)] prompt: Option, /// The number of times to run the prompt. #[arg(long, default_value = "1")] n: usize, /// L2 normalization for embeddings. #[arg(long, default_value = "true")] normalize_embeddings: bool, #[arg(long)] tokenizer: Option, #[arg(long)] model: Option, #[arg(long)] model_file: Option, } impl Args { fn build_model_and_tokenizer(&self) -> anyhow::Result<(BertModel, tokenizers::Tokenizer)> { use hf_hub::{api::sync::Api, Repo, RepoType}; let model_name = match self.model.as_ref() { Some(model) => model.to_string(), None => "jinaai/jina-embeddings-v2-base-en".to_string(), }; let model = match &self.model_file { Some(model_file) => std::path::PathBuf::from(model_file), None => Api::new()? .repo(Repo::new(model_name.to_string(), RepoType::Model)) .get("model.safetensors")?, }; let tokenizer = match &self.tokenizer { Some(file) => std::path::PathBuf::from(file), None => Api::new()? .repo(Repo::new(model_name.to_string(), RepoType::Model)) .get("tokenizer.json")?, }; let device = candle_examples::device(self.cpu)?; let tokenizer = tokenizers::Tokenizer::from_file(tokenizer).map_err(E::msg)?; let config = Config::new( tokenizer.get_vocab_size(true), 768, 12, 12, 3072, candle_nn::Activation::Gelu, 8192, 2, 0.02, 1e-12, 0, PositionEmbeddingType::Alibi, ); let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &device)? }; let model = BertModel::new(vb, &config)?; Ok((model, tokenizer)) } } fn main() -> anyhow::Result<()> { use tracing_chrome::ChromeLayerBuilder; use tracing_subscriber::prelude::*; let args = Args::parse(); let _guard = if args.tracing { println!("tracing..."); let (chrome_layer, guard) = ChromeLayerBuilder::new().build(); tracing_subscriber::registry().with(chrome_layer).init(); Some(guard) } else { None }; let start = std::time::Instant::now(); let (model, mut tokenizer) = args.build_model_and_tokenizer()?; let device = &model.device; if let Some(prompt) = args.prompt { let tokenizer = tokenizer .with_padding(None) .with_truncation(None) .map_err(E::msg)?; let tokens = tokenizer .encode(prompt, true) .map_err(E::msg)? .get_ids() .to_vec(); let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?; println!("Loaded and encoded {:?}", start.elapsed()); let start = std::time::Instant::now(); let embeddings = model.forward(&token_ids)?; let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?; let embeddings = (embeddings.sum(1)? / (n_tokens as f64))?; println!("pooled_embeddigns: {embeddings}"); let embeddings = if args.normalize_embeddings { normalize_l2(&embeddings)? } else { embeddings }; if args.normalize_embeddings { println!("normalized_embeddings: {embeddings}"); } println!("Took {:?}", start.elapsed()); } else { let sentences = [ "The cat sits outside", "A man is playing guitar", "I love pasta", "The new movie is awesome", "The cat plays in the garden", "A woman watches TV", "The new movie is so great", "Do you like pizza?", ]; let n_sentences = sentences.len(); if let Some(pp) = tokenizer.get_padding_mut() { pp.strategy = tokenizers::PaddingStrategy::BatchLongest } else { let pp = tokenizers::PaddingParams { strategy: tokenizers::PaddingStrategy::BatchLongest, ..Default::default() }; tokenizer.with_padding(Some(pp)); } let tokens = tokenizer .encode_batch(sentences.to_vec(), true) .map_err(E::msg)?; let token_ids = tokens .iter() .map(|tokens| { let tokens = tokens.get_ids().to_vec(); Tensor::new(tokens.as_slice(), device) }) .collect::>>()?; let token_ids = Tensor::stack(&token_ids, 0)?; println!("running inference on batch {:?}", token_ids.shape()); let embeddings = model.forward(&token_ids)?; println!("generated embeddings {:?}", embeddings.shape()); // Apply some avg-pooling by taking the mean embedding value for all tokens (including padding) let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?; let embeddings = (embeddings.sum(1)? / (n_tokens as f64))?; let embeddings = if args.normalize_embeddings { normalize_l2(&embeddings)? } else { embeddings }; println!("pooled embeddings {:?}", embeddings.shape()); let mut similarities = vec![]; for i in 0..n_sentences { let e_i = embeddings.get(i)?; for j in (i + 1)..n_sentences { let e_j = embeddings.get(j)?; let sum_ij = (&e_i * &e_j)?.sum_all()?.to_scalar::()?; let sum_i2 = (&e_i * &e_i)?.sum_all()?.to_scalar::()?; let sum_j2 = (&e_j * &e_j)?.sum_all()?.to_scalar::()?; let cosine_similarity = sum_ij / (sum_i2 * sum_j2).sqrt(); similarities.push((cosine_similarity, i, j)) } } similarities.sort_by(|u, v| v.0.total_cmp(&u.0)); for &(score, i, j) in similarities[..5].iter() { println!("score: {score:.2} '{}' '{}'", sentences[i], sentences[j]) } } Ok(()) } pub fn normalize_l2(v: &Tensor) -> candle::Result { v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?) }