summaryrefslogtreecommitdiff
path: root/candle-examples/examples/xlm-roberta/main.rs
diff options
context:
space:
mode:
authorAkshay Ballal <61191840+akshayballal95@users.noreply.github.com>2024-12-30 11:16:57 +0100
committerGitHub <noreply@github.com>2024-12-30 11:16:57 +0100
commit91f1f019b13386f4df3e9b2826c982d10bcc497e (patch)
tree1acf1a2921592bbd62ac43ea1321376164b8724b /candle-examples/examples/xlm-roberta/main.rs
parentcd639131f04990c16bfc498ea347cb9df3d2374f (diff)
downloadcandle-91f1f019b13386f4df3e9b2826c982d10bcc497e.tar.gz
candle-91f1f019b13386f4df3e9b2826c982d10bcc497e.tar.bz2
candle-91f1f019b13386f4df3e9b2826c982d10bcc497e.zip
Added XLMRobertaModel for Reranking (#2686)
* add xlm-roberta-base * Add task enum for fill-mask and reranker in xlm-roberta example; update README and fix attention mask dimensions - Introduced a new `Task` enum to replace string task identifiers in the xlm-roberta example. - Updated the logic in `main.rs` to handle tasks using the new enum. - Enhanced README with example output for fill-mask task. - Fixed dimension retrieval in `prepare_4d_attention_mask` function for better clarity and safety. * Clippy fix. --------- Co-authored-by: laurent <laurent.mazare@gmail.com>
Diffstat (limited to 'candle-examples/examples/xlm-roberta/main.rs')
-rw-r--r--candle-examples/examples/xlm-roberta/main.rs277
1 files changed, 277 insertions, 0 deletions
diff --git a/candle-examples/examples/xlm-roberta/main.rs b/candle-examples/examples/xlm-roberta/main.rs
new file mode 100644
index 00000000..47ab44b0
--- /dev/null
+++ b/candle-examples/examples/xlm-roberta/main.rs
@@ -0,0 +1,277 @@
+use std::path::PathBuf;
+
+use anyhow::{Error as E, Result};
+use candle::{Device, Tensor};
+use candle_nn::VarBuilder;
+use candle_transformers::models::xlm_roberta::{
+ Config, XLMRobertaForMaskedLM, XLMRobertaForSequenceClassification,
+};
+use clap::{Parser, ValueEnum};
+use hf_hub::{api::sync::Api, Repo, RepoType};
+use tokenizers::{PaddingParams, Tokenizer};
+
+#[derive(Debug, Clone, ValueEnum)]
+enum Model {
+ BgeRerankerBase,
+ BgeRerankerLarge,
+ BgeRerankerBaseV2,
+ XLMRobertaBase,
+ XLMRobertaLarge,
+}
+
+#[derive(Debug, Clone, ValueEnum)]
+enum Task {
+ FillMask,
+ Reranker,
+}
+
+#[derive(Parser, Debug)]
+#[command(author, version, about, long_about = None)]
+struct Args {
+ /// Run on CPU rather than on GPU.
+ #[arg(long)]
+ cpu: bool,
+
+ /// Enable tracing (generates a trace-timestamp.json file).
+ #[arg(long)]
+ tracing: bool,
+
+ /// The model to use, check out available models: https://huggingface.co/models?library=sentence-transformers&sort=trending
+ #[arg(long)]
+ model_id: Option<String>,
+
+ #[arg(long, default_value = "main")]
+ revision: String,
+
+ #[arg(long, default_value = "bge-reranker-base")]
+ model: Model,
+
+ #[arg(long, default_value = "reranker")]
+ task: Task,
+
+ // Path to the tokenizer file.
+ #[arg(long)]
+ tokenizer_file: Option<String>,
+
+ // Path to the weight files.
+ #[arg(long)]
+ weight_files: Option<String>,
+
+ // Path to the config file.
+ #[arg(long)]
+ config_file: Option<String>,
+
+ /// When set, compute embeddings for this prompt.
+ #[arg(long)]
+ prompt: Option<String>,
+}
+
+fn main() -> Result<()> {
+ let args = Args::parse();
+ let api = Api::new()?;
+ let model_id = match &args.model_id {
+ Some(model_id) => model_id.to_string(),
+ None => match args.task {
+ Task::FillMask => match args.model {
+ Model::XLMRobertaBase => "FacebookAI/xlm-roberta-base".to_string(),
+ Model::XLMRobertaLarge => "FacebookAI/xlm-roberta-large".to_string(),
+ _ => anyhow::bail!("BGE models are not supported for fill-mask task"),
+ },
+ Task::Reranker => match args.model {
+ Model::BgeRerankerBase => "BAAI/bge-reranker-base".to_string(),
+ Model::BgeRerankerLarge => "BAAI/bge-reranker-large".to_string(),
+ Model::BgeRerankerBaseV2 => "BAAI/bge-reranker-base-v2-m3".to_string(),
+ _ => anyhow::bail!("XLM-RoBERTa models are not supported for reranker task"),
+ },
+ },
+ };
+ let repo = api.repo(Repo::with_revision(
+ model_id,
+ RepoType::Model,
+ args.revision,
+ ));
+
+ let tokenizer_filename = match args.tokenizer_file {
+ Some(file) => std::path::PathBuf::from(file),
+ None => repo.get("tokenizer.json")?,
+ };
+
+ let config_filename = match args.config_file {
+ Some(file) => std::path::PathBuf::from(file),
+ None => repo.get("config.json")?,
+ };
+
+ let weights_filename = match args.weight_files {
+ Some(files) => PathBuf::from(files),
+ None => match repo.get("model.safetensors") {
+ Ok(safetensors) => safetensors,
+ Err(_) => match repo.get("pytorch_model.bin") {
+ Ok(pytorch_model) => pytorch_model,
+ Err(e) => {
+ return Err(anyhow::Error::msg(format!("Model weights not found. The weights should either be a `model.safetensors` or `pytorch_model.bin` file. Error: {}", e)));
+ }
+ },
+ },
+ };
+
+ let config = std::fs::read_to_string(config_filename)?;
+ let config: Config = serde_json::from_str(&config)?;
+ let mut tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
+
+ let device = candle_examples::device(args.cpu)?;
+
+ let vb = if weights_filename.ends_with("model.safetensors") {
+ unsafe {
+ VarBuilder::from_mmaped_safetensors(&[weights_filename], candle::DType::F16, &device)
+ .unwrap()
+ }
+ } else {
+ println!("Loading weights from pytorch_model.bin");
+ VarBuilder::from_pth(&weights_filename, candle::DType::F16, &device).unwrap()
+ };
+ tokenizer
+ .with_padding(Some(PaddingParams {
+ strategy: tokenizers::PaddingStrategy::BatchLongest,
+ pad_id: config.pad_token_id,
+ ..Default::default()
+ }))
+ .with_truncation(None)
+ .map_err(E::msg)?;
+
+ match args.task {
+ Task::FillMask => {
+ let prompt = vec![
+ "Hello I'm a <mask> model.".to_string(),
+ "I'm a <mask> boy.".to_string(),
+ "I'm <mask> in berlin.".to_string(),
+ ];
+ let model = XLMRobertaForMaskedLM::new(&config, vb)?;
+
+ let input_ids = tokenize_batch(&tokenizer, TokenizeInput::Single(&prompt), &device)?;
+ let attention_mask =
+ get_attention_mask(&tokenizer, TokenizeInput::Single(&prompt), &device)?;
+
+ let token_type_ids = Tensor::zeros(input_ids.dims(), input_ids.dtype(), &device)?;
+
+ let output = model
+ .forward(
+ &input_ids,
+ &attention_mask,
+ &token_type_ids,
+ None,
+ None,
+ None,
+ )?
+ .to_dtype(candle::DType::F32)?;
+
+ let max_outs = output.argmax(2)?;
+
+ let max_out = max_outs.to_vec2::<u32>()?;
+ let max_out_refs: Vec<&[u32]> = max_out.iter().map(|v| v.as_slice()).collect();
+ let decoded = tokenizer.decode_batch(&max_out_refs, true).unwrap();
+ for (i, sentence) in decoded.iter().enumerate() {
+ println!("Sentence: {} : {}", i + 1, sentence);
+ }
+ }
+ Task::Reranker => {
+ let query = "what is panda?".to_string();
+
+ let documents = ["South Korea is a country in East Asia.".to_string(),
+ "There are forests in the mountains.".to_string(),
+ "Pandas look like bears.".to_string(),
+ "There are some animals with black and white fur.".to_string(),
+ "The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China.".to_string()];
+
+ // create pairs of query and documents
+ let pairs = documents
+ .iter()
+ .map(|doc| (query.clone(), doc.clone()))
+ .collect::<Vec<_>>();
+ let input_ids = tokenize_batch(&tokenizer, TokenizeInput::Pairs(&pairs), &device)?;
+ let attention_mask =
+ get_attention_mask(&tokenizer, TokenizeInput::Pairs(&pairs), &device)?;
+ let token_type_ids = Tensor::zeros(input_ids.dims(), input_ids.dtype(), &device)?;
+
+ let model = XLMRobertaForSequenceClassification::new(1, &config, vb)?;
+
+ let output = model.forward(&input_ids, &attention_mask, &token_type_ids)?;
+ let output = candle_nn::ops::sigmoid(&output)?.t().unwrap();
+ let ranks = output
+ .arg_sort_last_dim(false)?
+ .to_vec2::<u32>()?
+ .into_iter()
+ .flatten()
+ .collect::<Vec<_>>();
+ println!("\nRanking Results:");
+ println!("{:-<80}", "");
+ documents.iter().enumerate().for_each(|(idx, doc)| {
+ let rank = ranks.iter().position(|&r| r == idx as u32).unwrap();
+ let score = output
+ .get_on_dim(1, idx)
+ .unwrap()
+ .to_dtype(candle::DType::F32)
+ .unwrap()
+ .to_vec1::<f32>()
+ .unwrap();
+ println!("Rank #{:<2} | Score: {:.4} | {}", rank + 1, score[0], doc);
+ });
+ println!("{:-<80}", "");
+ }
+ }
+ Ok(())
+}
+
+#[derive(Debug)]
+pub enum TokenizeInput<'a> {
+ Single(&'a [String]),
+ Pairs(&'a [(String, String)]),
+}
+
+pub fn tokenize_batch(
+ tokenizer: &Tokenizer,
+ input: TokenizeInput,
+ device: &Device,
+) -> anyhow::Result<Tensor> {
+ let tokens = match input {
+ TokenizeInput::Single(text_batch) => tokenizer
+ .encode_batch(text_batch.to_vec(), true)
+ .map_err(E::msg)?,
+ TokenizeInput::Pairs(pairs) => tokenizer
+ .encode_batch(pairs.to_vec(), true)
+ .map_err(E::msg)?,
+ };
+
+ let token_ids = tokens
+ .iter()
+ .map(|tokens| {
+ let tokens = tokens.get_ids().to_vec();
+ Tensor::new(tokens.as_slice(), device)
+ })
+ .collect::<candle::Result<Vec<_>>>()?;
+
+ Ok(Tensor::stack(&token_ids, 0)?)
+}
+
+pub fn get_attention_mask(
+ tokenizer: &Tokenizer,
+ input: TokenizeInput,
+ device: &Device,
+) -> anyhow::Result<Tensor> {
+ let tokens = match input {
+ TokenizeInput::Single(text_batch) => tokenizer
+ .encode_batch(text_batch.to_vec(), true)
+ .map_err(E::msg)?,
+ TokenizeInput::Pairs(pairs) => tokenizer
+ .encode_batch(pairs.to_vec(), true)
+ .map_err(E::msg)?,
+ };
+
+ let attention_mask = tokens
+ .iter()
+ .map(|tokens| {
+ let tokens = tokens.get_attention_mask().to_vec();
+ Tensor::new(tokens.as_slice(), device)
+ })
+ .collect::<candle::Result<Vec<_>>>()?;
+ Ok(Tensor::stack(&attention_mask, 0)?)
+}