#[cfg(feature = "mkl")] extern crate intel_mkl_src; #[cfg(feature = "accelerate")] extern crate accelerate_src; use candle_transformers::models::distilbert::{Config, DistilBertModel, DTYPE}; use anyhow::{Error as E, Result}; use candle::{Device, Tensor}; use candle_nn::VarBuilder; use clap::Parser; use hf_hub::{api::sync::Api, Repo, RepoType}; use tokenizers::Tokenizer; #[derive(Parser, Debug)] #[command(author, version, about, long_about = None)] struct Args { /// Run on CPU rather than on GPU. #[arg(long)] cpu: bool, /// Enable tracing (generates a trace-timestamp.json file). #[arg(long)] tracing: bool, /// The model to use, check out available models: https://huggingface.co/models?library=sentence-transformers&sort=trending #[arg(long)] model_id: Option, #[arg(long)] revision: Option, /// When set, compute embeddings for this prompt. #[arg(long)] prompt: String, /// Use the pytorch weights rather than the safetensors ones #[arg(long)] use_pth: bool, /// The number of times to run the prompt. #[arg(long, default_value = "1")] n: usize, /// L2 normalization for embeddings. #[arg(long, default_value = "true")] normalize_embeddings: bool, } impl Args { fn build_model_and_tokenizer(&self) -> Result<(DistilBertModel, Tokenizer)> { let device = candle_examples::device(self.cpu)?; let default_model = "distilbert-base-uncased".to_string(); let default_revision = "main".to_string(); let (model_id, revision) = match (self.model_id.to_owned(), self.revision.to_owned()) { (Some(model_id), Some(revision)) => (model_id, revision), (Some(model_id), None) => (model_id, "main".to_string()), (None, Some(revision)) => (default_model, revision), (None, None) => (default_model, default_revision), }; let repo = Repo::with_revision(model_id, RepoType::Model, revision); let (config_filename, tokenizer_filename, weights_filename) = { let api = Api::new()?; let api = api.repo(repo); let config = api.get("config.json")?; let tokenizer = api.get("tokenizer.json")?; let weights = if self.use_pth { api.get("pytorch_model.bin")? } else { api.get("model.safetensors")? }; (config, tokenizer, weights) }; let config = std::fs::read_to_string(config_filename)?; let config: Config = serde_json::from_str(&config)?; let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; let vb = if self.use_pth { VarBuilder::from_pth(&weights_filename, DTYPE, &device)? } else { unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)? } }; let model = DistilBertModel::load(vb, &config)?; Ok((model, tokenizer)) } } fn get_mask(size: usize, device: &Device) -> Tensor { let mask: Vec<_> = (0..size) .flat_map(|i| (0..size).map(move |j| u8::from(j > i))) .collect(); Tensor::from_slice(&mask, (size, size), device).unwrap() } fn main() -> Result<()> { use tracing_chrome::ChromeLayerBuilder; use tracing_subscriber::prelude::*; let args = Args::parse(); let _guard = if args.tracing { println!("tracing..."); let (chrome_layer, guard) = ChromeLayerBuilder::new().build(); tracing_subscriber::registry().with(chrome_layer).init(); Some(guard) } else { None }; let (model, mut tokenizer) = args.build_model_and_tokenizer()?; let device = &model.device; let tokenizer = tokenizer .with_padding(None) .with_truncation(None) .map_err(E::msg)?; let tokens = tokenizer .encode(args.prompt, true) .map_err(E::msg)? 
.get_ids() .to_vec(); let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?; let mask = get_mask(tokens.len(), device); println!("token_ids: {:?}", token_ids.to_vec2::()); println!("mask: {:?}", mask.to_vec2::()); let ys = model.forward(&token_ids, &mask)?; println!("{ys}"); Ok(()) } pub fn normalize_l2(v: &Tensor) -> Result { Ok(v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)?) }
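
// `normalize_l2` and the `--normalize-embeddings` flag above are declared but
// never wired into `main`. A minimal sketch of how the per-token hidden states
// could be pooled into a single sentence embedding; `sentence_embedding` is a
// hypothetical helper (not part of the example above), and it assumes `ys`
// has shape (batch, seq_len, hidden_size).
pub fn sentence_embedding(ys: &Tensor, normalize: bool) -> Result<Tensor> {
    // Mean-pool over the token axis: (batch, seq_len, hidden) -> (batch, hidden).
    let (_batch, seq_len, _hidden) = ys.dims3()?;
    let pooled = (ys.sum(1)? / (seq_len as f64))?;
    if normalize {
        // Scale each row to unit L2 norm, mirroring the flag's intent.
        normalize_l2(&pooled)
    } else {
        Ok(pooled)
    }
}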