| author | Akshay Ballal <61191840+akshayballal95@users.noreply.github.com> | 2024-12-30 11:16:57 +0100 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-12-30 11:16:57 +0100 |
| commit | 91f1f019b13386f4df3e9b2826c982d10bcc497e (patch) | |
| tree | 1acf1a2921592bbd62ac43ea1321376164b8724b | |
| parent | cd639131f04990c16bfc498ea347cb9df3d2374f (diff) | |
Added XLMRobertaModel for Reranking (#2686)
* add xlm-roberta-base
* Add task enum for fill-mask and reranker in xlm-roberta example; update README and fix attention mask dimensions
- Introduced a new `Task` enum to replace string task identifiers in the xlm-roberta example.
- Updated the logic in `main.rs` to handle tasks using the new enum.
- Enhanced README with example output for fill-mask task.
- Fixed dimension retrieval in the `prepare_4d_attention_mask` function for better clarity and safety (see the short mask-expansion sketch after the commit message).
* Clippy fix.
---------
Co-authored-by: laurent <laurent.mazare@gmail.com>
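
For context on the attention-mask bullet above: `prepare_4d_attention_mask` (added in the diff below) expands the tokenizer's 2D padding mask of shape `(batch, seq_len)` into a 4D additive mask of shape `(batch, 1, seq_len, seq_len)`, where positions that may be attended to become `0.0` and padded positions become a large negative value that is later added to the attention scores. Below is a minimal sketch of that expansion on a toy mask, assuming only the `candle-core` crate (imported as `candle`); the standalone `main` and the toy values are illustrative and not part of the commit.

```rust
use candle::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Two sequences of length 3; the second one ends with a padding position.
    let mask = Tensor::new(&[[1u32, 1, 1], [1, 1, 0]], &device)?;
    let (bsz, src_len) = mask.dims2()?; // read both dims from the mask itself

    let expanded = mask
        .unsqueeze(1)? // (bsz, 1, src_len)
        .unsqueeze(2)? // (bsz, 1, 1, src_len)
        .expand((bsz, 1, src_len, src_len))? // broadcast over the query axis
        .to_dtype(DType::F32)?;

    // 1 -> 0.0 (keep), 0 -> f32::MIN (masked out once added to the scores).
    let additive = ((1.0 - expanded)? * f32::MIN as f64)?;
    println!("{:?}", additive.dims()); // [2, 1, 3, 3]
    Ok(())
}
```

Every query position in the second sequence gets `f32::MIN` for the padded key position and `0.0` elsewhere, which is what `XLMRobertaSelfAttention` adds to its attention scores before the softmax.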
-rw-r--r-- | candle-examples/examples/xlm-roberta/Readme.md | 30
-rw-r--r-- | candle-examples/examples/xlm-roberta/main.rs | 277
-rw-r--r-- | candle-transformers/src/models/mod.rs | 1
-rw-r--r-- | candle-transformers/src/models/xlm_roberta.rs | 545
4 files changed, 853 insertions, 0 deletions
diff --git a/candle-examples/examples/xlm-roberta/Readme.md b/candle-examples/examples/xlm-roberta/Readme.md
new file mode 100644
index 00000000..496b14e3
--- /dev/null
+++ b/candle-examples/examples/xlm-roberta/Readme.md
@@ -0,0 +1,30 @@
+# candle-xlm-roberta
+
+This example demonstrates how to use the XLM-RoBERTa family of models in Candle, models that are especially known for their use in reranking. It supports a `fill-mask` task that predicts a word for a masked token, and a `reranker` task that reranks a list of documents for a given query.
+
+## Usage
+
+Fill Mask:
+```bash
+cargo run --example xlm-roberta --release -- --task fill-mask --model xlm-roberta-base
+```
+```markdown
+Sentence: 0 : Hello I'm a fashion model.
+Sentence: 1 : I'm a little boy.
+Sentence: 2 : I'm living in berlin.
+```
+
+Reranker:
+```bash
+cargo run --example xlm-roberta --release -- --task reranker --model bge-reranker-base
+```
+```markdown
+Ranking Results:
+--------------------------------------------------------------------------------
+> Rank #4 | Score: 0.0001 | South Korea is a country in East Asia.
+> Rank #5 | Score: 0.0000 | There are forests in the mountains.
+> Rank #2 | Score: 0.7314 | Pandas look like bears.
+> Rank #3 | Score: 0.6948 | There are some animals with black and white fur.
+> Rank #1 | Score: 0.9990 | The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China.
+--------------------------------------------------------------------------------
+```
diff --git a/candle-examples/examples/xlm-roberta/main.rs b/candle-examples/examples/xlm-roberta/main.rs
new file mode 100644
index 00000000..47ab44b0
--- /dev/null
+++ b/candle-examples/examples/xlm-roberta/main.rs
@@ -0,0 +1,277 @@
+use std::path::PathBuf;
+
+use anyhow::{Error as E, Result};
+use candle::{Device, Tensor};
+use candle_nn::VarBuilder;
+use candle_transformers::models::xlm_roberta::{
+    Config, XLMRobertaForMaskedLM, XLMRobertaForSequenceClassification,
+};
+use clap::{Parser, ValueEnum};
+use hf_hub::{api::sync::Api, Repo, RepoType};
+use tokenizers::{PaddingParams, Tokenizer};
+
+#[derive(Debug, Clone, ValueEnum)]
+enum Model {
+    BgeRerankerBase,
+    BgeRerankerLarge,
+    BgeRerankerBaseV2,
+    XLMRobertaBase,
+    XLMRobertaLarge,
+}
+
+#[derive(Debug, Clone, ValueEnum)]
+enum Task {
+    FillMask,
+    Reranker,
+}
+
+#[derive(Parser, Debug)]
+#[command(author, version, about, long_about = None)]
+struct Args {
+    /// Run on CPU rather than on GPU.
+    #[arg(long)]
+    cpu: bool,
+
+    /// Enable tracing (generates a trace-timestamp.json file).
+    #[arg(long)]
+    tracing: bool,
+
+    /// The model to use, check out available models: https://huggingface.co/models?library=sentence-transformers&sort=trending
+    #[arg(long)]
+    model_id: Option<String>,
+
+    #[arg(long, default_value = "main")]
+    revision: String,
+
+    #[arg(long, default_value = "bge-reranker-base")]
+    model: Model,
+
+    #[arg(long, default_value = "reranker")]
+    task: Task,
+
+    // Path to the tokenizer file.
+    #[arg(long)]
+    tokenizer_file: Option<String>,
+
+    // Path to the weight files.
+    #[arg(long)]
+    weight_files: Option<String>,
+
+    // Path to the config file.
+    #[arg(long)]
+    config_file: Option<String>,
+
+    /// When set, compute embeddings for this prompt.
+    #[arg(long)]
+    prompt: Option<String>,
+}
+
+fn main() -> Result<()> {
+    let args = Args::parse();
+    let api = Api::new()?;
+    let model_id = match &args.model_id {
+        Some(model_id) => model_id.to_string(),
+        None => match args.task {
+            Task::FillMask => match args.model {
+                Model::XLMRobertaBase => "FacebookAI/xlm-roberta-base".to_string(),
+                Model::XLMRobertaLarge => "FacebookAI/xlm-roberta-large".to_string(),
+                _ => anyhow::bail!("BGE models are not supported for fill-mask task"),
+            },
+            Task::Reranker => match args.model {
+                Model::BgeRerankerBase => "BAAI/bge-reranker-base".to_string(),
+                Model::BgeRerankerLarge => "BAAI/bge-reranker-large".to_string(),
+                Model::BgeRerankerBaseV2 => "BAAI/bge-reranker-base-v2-m3".to_string(),
+                _ => anyhow::bail!("XLM-RoBERTa models are not supported for reranker task"),
+            },
+        },
+    };
+    let repo = api.repo(Repo::with_revision(
+        model_id,
+        RepoType::Model,
+        args.revision,
+    ));
+
+    let tokenizer_filename = match args.tokenizer_file {
+        Some(file) => std::path::PathBuf::from(file),
+        None => repo.get("tokenizer.json")?,
+    };
+
+    let config_filename = match args.config_file {
+        Some(file) => std::path::PathBuf::from(file),
+        None => repo.get("config.json")?,
+    };
+
+    let weights_filename = match args.weight_files {
+        Some(files) => PathBuf::from(files),
+        None => match repo.get("model.safetensors") {
+            Ok(safetensors) => safetensors,
+            Err(_) => match repo.get("pytorch_model.bin") {
+                Ok(pytorch_model) => pytorch_model,
+                Err(e) => {
+                    return Err(anyhow::Error::msg(format!("Model weights not found. The weights should either be a `model.safetensors` or `pytorch_model.bin` file. Error: {}", e)));
+                }
+            },
+        },
+    };
+
+    let config = std::fs::read_to_string(config_filename)?;
+    let config: Config = serde_json::from_str(&config)?;
+    let mut tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
+
+    let device = candle_examples::device(args.cpu)?;
+
+    let vb = if weights_filename.ends_with("model.safetensors") {
+        unsafe {
+            VarBuilder::from_mmaped_safetensors(&[weights_filename], candle::DType::F16, &device)
+                .unwrap()
+        }
+    } else {
+        println!("Loading weights from pytorch_model.bin");
+        VarBuilder::from_pth(&weights_filename, candle::DType::F16, &device).unwrap()
+    };
+    tokenizer
+        .with_padding(Some(PaddingParams {
+            strategy: tokenizers::PaddingStrategy::BatchLongest,
+            pad_id: config.pad_token_id,
+            ..Default::default()
+        }))
+        .with_truncation(None)
+        .map_err(E::msg)?;
+
+    match args.task {
+        Task::FillMask => {
+            let prompt = vec![
+                "Hello I'm a <mask> model.".to_string(),
+                "I'm a <mask> boy.".to_string(),
+                "I'm <mask> in berlin.".to_string(),
+            ];
+            let model = XLMRobertaForMaskedLM::new(&config, vb)?;
+
+            let input_ids = tokenize_batch(&tokenizer, TokenizeInput::Single(&prompt), &device)?;
+            let attention_mask =
+                get_attention_mask(&tokenizer, TokenizeInput::Single(&prompt), &device)?;
+
+            let token_type_ids = Tensor::zeros(input_ids.dims(), input_ids.dtype(), &device)?;
+
+            let output = model
+                .forward(
+                    &input_ids,
+                    &attention_mask,
+                    &token_type_ids,
+                    None,
+                    None,
+                    None,
+                )?
+                .to_dtype(candle::DType::F32)?;
+
+            let max_outs = output.argmax(2)?;
+
+            let max_out = max_outs.to_vec2::<u32>()?;
+            let max_out_refs: Vec<&[u32]> = max_out.iter().map(|v| v.as_slice()).collect();
+            let decoded = tokenizer.decode_batch(&max_out_refs, true).unwrap();
+            for (i, sentence) in decoded.iter().enumerate() {
+                println!("Sentence: {} : {}", i + 1, sentence);
+            }
+        }
+        Task::Reranker => {
+            let query = "what is panda?".to_string();
+
+            let documents = ["South Korea is a country in East Asia.".to_string(),
+                "There are forests in the mountains.".to_string(),
+                "Pandas look like bears.".to_string(),
+                "There are some animals with black and white fur.".to_string(),
+                "The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China.".to_string()];
+
+            // create pairs of query and documents
+            let pairs = documents
+                .iter()
+                .map(|doc| (query.clone(), doc.clone()))
+                .collect::<Vec<_>>();
+            let input_ids = tokenize_batch(&tokenizer, TokenizeInput::Pairs(&pairs), &device)?;
+            let attention_mask =
+                get_attention_mask(&tokenizer, TokenizeInput::Pairs(&pairs), &device)?;
+            let token_type_ids = Tensor::zeros(input_ids.dims(), input_ids.dtype(), &device)?;
+
+            let model = XLMRobertaForSequenceClassification::new(1, &config, vb)?;
+
+            let output = model.forward(&input_ids, &attention_mask, &token_type_ids)?;
+            let output = candle_nn::ops::sigmoid(&output)?.t().unwrap();
+            let ranks = output
+                .arg_sort_last_dim(false)?
+                .to_vec2::<u32>()?
+                .into_iter()
+                .flatten()
+                .collect::<Vec<_>>();
+            println!("\nRanking Results:");
+            println!("{:-<80}", "");
+            documents.iter().enumerate().for_each(|(idx, doc)| {
+                let rank = ranks.iter().position(|&r| r == idx as u32).unwrap();
+                let score = output
+                    .get_on_dim(1, idx)
+                    .unwrap()
+                    .to_dtype(candle::DType::F32)
+                    .unwrap()
+                    .to_vec1::<f32>()
+                    .unwrap();
+                println!("Rank #{:<2} | Score: {:.4} | {}", rank + 1, score[0], doc);
+            });
+            println!("{:-<80}", "");
+        }
+    }
+    Ok(())
+}
+
+#[derive(Debug)]
+pub enum TokenizeInput<'a> {
+    Single(&'a [String]),
+    Pairs(&'a [(String, String)]),
+}
+
+pub fn tokenize_batch(
+    tokenizer: &Tokenizer,
+    input: TokenizeInput,
+    device: &Device,
+) -> anyhow::Result<Tensor> {
+    let tokens = match input {
+        TokenizeInput::Single(text_batch) => tokenizer
+            .encode_batch(text_batch.to_vec(), true)
+            .map_err(E::msg)?,
+        TokenizeInput::Pairs(pairs) => tokenizer
+            .encode_batch(pairs.to_vec(), true)
+            .map_err(E::msg)?,
+    };
+
+    let token_ids = tokens
+        .iter()
+        .map(|tokens| {
+            let tokens = tokens.get_ids().to_vec();
+            Tensor::new(tokens.as_slice(), device)
+        })
+        .collect::<candle::Result<Vec<_>>>()?;
+
+    Ok(Tensor::stack(&token_ids, 0)?)
+}
+
+pub fn get_attention_mask(
+    tokenizer: &Tokenizer,
+    input: TokenizeInput,
+    device: &Device,
+) -> anyhow::Result<Tensor> {
+    let tokens = match input {
+        TokenizeInput::Single(text_batch) => tokenizer
+            .encode_batch(text_batch.to_vec(), true)
+            .map_err(E::msg)?,
+        TokenizeInput::Pairs(pairs) => tokenizer
+            .encode_batch(pairs.to_vec(), true)
+            .map_err(E::msg)?,
+    };
+
+    let attention_mask = tokens
+        .iter()
+        .map(|tokens| {
+            let tokens = tokens.get_attention_mask().to_vec();
+            Tensor::new(tokens.as_slice(), device)
+        })
+        .collect::<candle::Result<Vec<_>>>()?;
+    Ok(Tensor::stack(&attention_mask, 0)?)
+}
diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs
index be1f15c4..5f566991 100644
--- a/candle-transformers/src/models/mod.rs
+++ b/candle-transformers/src/models/mod.rs
@@ -109,4 +109,5 @@ pub mod vit;
 pub mod whisper;
 pub mod with_tracing;
 pub mod wuerstchen;
+pub mod xlm_roberta;
 pub mod yi;
diff --git a/candle-transformers/src/models/xlm_roberta.rs b/candle-transformers/src/models/xlm_roberta.rs
new file mode 100644
index 00000000..96e763e1
--- /dev/null
+++ b/candle-transformers/src/models/xlm_roberta.rs
@@ -0,0 +1,545 @@
+use crate::models::with_tracing::{linear, Linear};
+use candle::{DType, Module, Result, Tensor};
+use candle_nn::{
+    embedding, layer_norm, ops::softmax_last_dim, Activation, Embedding, LayerNorm, VarBuilder,
+};
+
+#[derive(Debug, Clone, serde::Deserialize)]
+pub struct Config {
+    pub hidden_size: usize,
+    pub layer_norm_eps: f64,
+    pub attention_probs_dropout_prob: f32,
+    pub hidden_dropout_prob: f32,
+    pub num_attention_heads: usize,
+    pub position_embedding_type: String,
+    pub intermediate_size: usize,
+    pub hidden_act: Activation,
+    pub num_hidden_layers: usize,
+    pub vocab_size: usize,
+    pub max_position_embeddings: usize,
+    pub type_vocab_size: usize,
+    pub pad_token_id: u32,
+}
+
+struct XLMRobertaEmbeddings {
+    word_embeddings: Embedding,
+    position_embeddings: Option<Embedding>,
+    token_type_embeddings: Embedding,
+    layer_norm: LayerNorm,
+    padding_idx: u32,
+    span: tracing::Span,
+}
+
+impl XLMRobertaEmbeddings {
+    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
+        let word_embeddings = embedding(
+            config.vocab_size,
+            config.hidden_size,
+            vb.pp("word_embeddings"),
+        )?;
+        let position_embeddings = embedding(
+            config.max_position_embeddings,
+            config.hidden_size,
+            vb.pp("position_embeddings"),
+        )?;
+        let token_type_embeddings = embedding(
+            config.type_vocab_size,
+            config.hidden_size,
+            vb.pp("token_type_embeddings"),
+        )?;
+        let layer_norm = layer_norm(
+            config.hidden_size,
+            config.layer_norm_eps,
+            vb.pp("LayerNorm"),
+        )?;
+        Ok(Self {
+            word_embeddings,
+            position_embeddings: Some(position_embeddings),
+            token_type_embeddings,
+            layer_norm,
+            padding_idx: config.pad_token_id,
+            span: tracing::span!(tracing::Level::TRACE, "embeddings"),
+        })
+    }
+
+    fn forward(&self, input_ids: &Tensor, token_type_ids: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        let (_bsize, _) = input_ids.dims2()?;
+        let input_embeddings = self.word_embeddings.forward(input_ids)?;
+        let token_type_embeddings = self.token_type_embeddings.forward(token_type_ids)?;
+        let mut embeddings = (&input_embeddings + token_type_embeddings)?;
+        if let Some(position_embeddings) = &self.position_embeddings {
+            let mask = input_ids
+                .ne(self.padding_idx)?
+                .to_dtype(input_embeddings.dtype())?;
+            let cumsum = mask.cumsum(1)?;
+            let position_ids = (cumsum * mask)?
+                .broadcast_add(
+                    &Tensor::try_from(self.padding_idx)?
+                        .to_dtype(input_embeddings.dtype())?
+                        .to_device(input_embeddings.device())?,
+                )?
+                .to_dtype(candle::DType::U32)?;
+            embeddings = embeddings.broadcast_add(&position_embeddings.forward(&position_ids)?)?;
+        }
+        let embeddings = self.layer_norm.forward(&embeddings)?;
+        Ok(embeddings)
+    }
+}
+
+struct XLMRobertaSelfAttention {
+    num_attention_heads: usize,
+    attention_head_size: usize,
+    all_head_size: usize,
+    query: Linear,
+    key: Linear,
+    value: Linear,
+}
+
+impl XLMRobertaSelfAttention {
+    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let attention_head_size = cfg.hidden_size / cfg.num_attention_heads;
+        let all_head_size = cfg.num_attention_heads * attention_head_size;
+        Ok(Self {
+            num_attention_heads: cfg.num_attention_heads,
+            attention_head_size,
+            all_head_size,
+            query: linear(cfg.hidden_size, all_head_size, vb.pp("query"))?,
+            key: linear(cfg.hidden_size, all_head_size, vb.pp("key"))?,
+            value: linear(cfg.hidden_size, all_head_size, vb.pp("value"))?,
+        })
+    }
+
+    fn transpose_for_scores(&self, x: &Tensor) -> Result<Tensor> {
+        let mut new_x_shape = x.dims().to_vec();
+        new_x_shape[2] = self.num_attention_heads;
+        new_x_shape.push(self.attention_head_size);
+        let x = x.reshape(new_x_shape)?;
+        x.permute((0, 2, 1, 3))?.contiguous()
+    }
+
+    fn forward(
+        &self,
+        hidden_states: &Tensor,
+        encoder_hidden_states: Option<&Tensor>,
+        attention_mask: &Tensor,
+        past_key_value: Option<(&Tensor, &Tensor)>,
+        encoder_attention_mask: Option<&Tensor>,
+    ) -> Result<Tensor> {
+        let mixed_query_layer = self.query.forward(hidden_states)?;
+        let is_cross_attention = encoder_hidden_states.is_some();
+        let (key_layer, value_layer, attention_mask) = if is_cross_attention
+            && past_key_value.is_some()
+        {
+            let key_layer = past_key_value.unwrap().0.clone();
+            let value_layer = past_key_value.unwrap().1.clone();
+            let attention_mask = encoder_attention_mask.unwrap().clone();
+            (key_layer, value_layer, Some(attention_mask))
+        } else if is_cross_attention {
+            let key_layer =
+                self.transpose_for_scores(&self.key.forward(encoder_hidden_states.unwrap())?)?;
+            let value_layer =
+                self.transpose_for_scores(&self.value.forward(encoder_hidden_states.unwrap())?)?;
+            let attention_mask = encoder_attention_mask.unwrap();
+            (key_layer, value_layer, Some(attention_mask.clone()))
+        } else if past_key_value.is_some() {
+            let mut key_layer = self.transpose_for_scores(&self.key.forward(hidden_states)?)?;
+            let mut value_layer = self.transpose_for_scores(&self.value.forward(hidden_states)?)?;
+            key_layer = Tensor::cat(
+                &[
+                    past_key_value.clone().as_ref().unwrap().0.clone(),
+                    key_layer,
+                ],
+                2,
+            )?;
+            value_layer = Tensor::cat(
+                &[past_key_value.as_ref().unwrap().1.clone(), value_layer],
+                2,
+            )?;
+            (key_layer, value_layer, Some(attention_mask.clone()))
+        } else {
+            let key_layer = self.transpose_for_scores(&self.key.forward(hidden_states)?)?;
+            let value_layer = self.transpose_for_scores(&self.value.forward(hidden_states)?)?;
+            (key_layer, value_layer, Some(attention_mask.clone()))
+        };
+
+        let query_layer = self.transpose_for_scores(&mixed_query_layer)?;
+        let mut attention_scores = query_layer.matmul(&key_layer.transpose(2, 3)?)?;
+        let scale = 1f64 / f64::sqrt(self.attention_head_size as f64);
+
+        attention_scores = (attention_scores * scale)?;
+        attention_scores = match attention_mask {
+            None => attention_scores,
+            Some(mask) => {
+                attention_scores.broadcast_add(&mask.to_dtype(attention_scores.dtype())?)?
+            }
+        };
+        let attention_probs = softmax_last_dim(&attention_scores)?;
+
+        let context_layer = attention_probs
+            .matmul(&value_layer)?
+            .permute((0, 2, 1, 3))?
+            .contiguous()?;
+        let mut new_context_layer_shape =
+            context_layer.dims()[..context_layer.dims().len() - 2].to_vec();
+        new_context_layer_shape.push(self.all_head_size);
+        let context_layer = context_layer.reshape(new_context_layer_shape)?;
+
+        Ok(context_layer)
+    }
+}
+
+struct XLMRobertaSelfOutput {
+    dense: Linear,
+    layernorm: LayerNorm,
+}
+
+impl XLMRobertaSelfOutput {
+    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let dense = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("dense"))?;
+        let layernorm =
+            candle_nn::layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("LayerNorm"))?;
+        Ok(Self { dense, layernorm })
+    }
+
+    fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
+        let hidden_states = self.dense.forward(hidden_states)?;
+        let hidden_states = self.layernorm.forward(&(hidden_states + input_tensor)?)?;
+        Ok(hidden_states)
+    }
+}
+
+struct XLMRobertaAttention {
+    output: XLMRobertaSelfOutput,
+    self_attention: XLMRobertaSelfAttention,
+}
+
+impl XLMRobertaAttention {
+    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let output = XLMRobertaSelfOutput::new(cfg, vb.pp("output"))?;
+        let self_attention = XLMRobertaSelfAttention::new(cfg, vb.pp("self"))?;
+        Ok(Self {
+            output,
+            self_attention,
+        })
+    }
+
+    fn forward(
+        &self,
+        hidden_states: &Tensor,
+        attention_mask: &Tensor,
+        encoder_hidden_states: Option<&Tensor>,
+        encoder_attention_mask: Option<&Tensor>,
+        past_key_value: Option<(&Tensor, &Tensor)>,
+    ) -> Result<(Tensor, Tensor)> {
+        let self_outputs = self.self_attention.forward(
+            hidden_states,
+            encoder_hidden_states,
+            attention_mask,
+            past_key_value,
+            encoder_attention_mask,
+        )?;
+        let attention_output = self.output.forward(&self_outputs, hidden_states)?;
+        Ok((attention_output, self_outputs))
+    }
+}
+
+struct XLMRobertaOutput {
+    dense: Linear,
+    layernorm: LayerNorm,
+}
+
+impl XLMRobertaOutput {
+    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let dense = linear(cfg.intermediate_size, cfg.hidden_size, vb.pp("dense"))?;
+        let layernorm =
+            candle_nn::layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("LayerNorm"))?;
+        Ok(Self { dense, layernorm })
+    }
+
+    fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
+        let hidden_states = self.dense.forward(hidden_states)?;
+        let hidden_states = self.layernorm.forward(&(hidden_states + input_tensor)?)?;
+        Ok(hidden_states)
+    }
+}
+
+struct XLMRobertaIntermediate {
+    dense: Linear,
+    intermediate_act_fn: Activation,
+}
+
+impl XLMRobertaIntermediate {
+    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let dense = linear(cfg.hidden_size, cfg.intermediate_size, vb.pp("dense"))?;
+        let intermediate_act_fn = cfg.hidden_act;
+        Ok(Self {
+            dense,
+            intermediate_act_fn,
+        })
+    }
+
+    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
+        let hidden_states = self.dense.forward(hidden_states)?;
+        let hidden_states = self.intermediate_act_fn.forward(&hidden_states)?;
+        Ok(hidden_states)
+    }
+}
+
+struct XLMRobertaLayer {
+    attention: XLMRobertaAttention,
+    intermediate: XLMRobertaIntermediate,
+    output: XLMRobertaOutput,
+}
+
+impl XLMRobertaLayer {
+    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let attention = XLMRobertaAttention::new(cfg, vb.pp("attention"))?;
+        let intermediate = XLMRobertaIntermediate::new(cfg, vb.pp("intermediate"))?;
+        let output = XLMRobertaOutput::new(cfg, vb.pp("output"))?;
+        Ok(Self {
+            attention,
+            intermediate,
+            output,
+        })
+    }
+
+    fn forward(
+        &self,
+        hidden_states: &Tensor,
+        attention_mask: &Tensor,
+        encoder_hidden_states: Option<&Tensor>,
+        encoder_attention_mask: Option<&Tensor>,
+        past_key_value: Option<(&Tensor, &Tensor)>,
+    ) -> Result<(Tensor, Tensor)> {
+        let self_attention_outputs = self.attention.forward(
+            hidden_states,
+            attention_mask,
+            encoder_hidden_states,
+            encoder_attention_mask,
+            past_key_value,
+        )?;
+        let attention_output = self_attention_outputs.0;
+        let outputs = self_attention_outputs.1;
+        let intermediate_output = self.intermediate.forward(&attention_output)?;
+        let layer_output = self
+            .output
+            .forward(&intermediate_output, &attention_output)?;
+        Ok((layer_output, outputs))
+    }
+}
+
+struct XLMRobertaEncoder {
+    layers: Vec<XLMRobertaLayer>,
+}
+
+impl XLMRobertaEncoder {
+    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let layers = (0..cfg.num_hidden_layers)
+            .map(|i| XLMRobertaLayer::new(cfg, vb.pp(format!("layer.{}", i))))
+            .collect::<Result<Vec<_>>>()?;
+        Ok(Self { layers })
+    }
+
+    fn forward(
+        &self,
+        hidden_states: &Tensor,
+        attention_mask: &Tensor,
+        encoder_hidden_states: Option<&Tensor>,
+        encoder_attention_mask: Option<&Tensor>,
+        past_key_value: Option<(&Tensor, &Tensor)>,
+    ) -> Result<Tensor> {
+        let mut hidden_states = hidden_states.clone();
+        for layer_module in self.layers.iter() {
+            let layer_outputs = layer_module.forward(
+                &hidden_states,
+                attention_mask,
+                encoder_hidden_states,
+                encoder_attention_mask,
+                past_key_value,
+            )?;
+            hidden_states = layer_outputs.0;
+        }
+        Ok(hidden_states)
+    }
+}
+
+pub struct XLMRobertaModel {
+    encoder: XLMRobertaEncoder,
+    embeddings: XLMRobertaEmbeddings,
+}
+
+impl XLMRobertaModel {
+    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let encoder = XLMRobertaEncoder::new(cfg, vb.pp("encoder"))?;
+        let embeddings = XLMRobertaEmbeddings::load(vb.pp("embeddings"), cfg)?;
+        Ok(Self {
+            encoder,
+            embeddings,
+        })
+    }
+
+    pub fn forward(
+        &self,
+        input_ids: &Tensor,
+        attention_mask: &Tensor,
+        token_type_ids: &Tensor,
+        past_key_value: Option<(&Tensor, &Tensor)>,
+        encoder_hidden_states: Option<&Tensor>,
+        encoder_attention_mask: Option<&Tensor>,
+    ) -> Result<Tensor> {
+        let hidden_states = self.embeddings.forward(input_ids, token_type_ids)?;
+        let attention_mask = prepare_4d_attention_mask(attention_mask, DType::F32, None)?
+            .to_device(hidden_states.device())?;
+        let hidden_states = self.encoder.forward(
+            &hidden_states,
+            &attention_mask,
+            encoder_hidden_states,
+            encoder_attention_mask,
+            past_key_value,
+        )?;
+        Ok(hidden_states)
+    }
+}
+
+struct XLMRobertaLMHead {
+    dense: Linear,
+    layer_norm: LayerNorm,
+}
+
+impl XLMRobertaLMHead {
+    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let dense = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("dense"))?;
+        let layer_norm =
+            candle_nn::layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("layer_norm"))?;
+        Ok(Self { dense, layer_norm })
+    }
+
+    fn forward(&self, hidden_states: &Tensor, shared_embeddings: &Tensor) -> Result<Tensor> {
+        let hidden_states = self.dense.forward(hidden_states)?;
+        let hidden_states = candle_nn::Activation::Gelu.forward(&hidden_states)?;
+        let hidden_states = self.layer_norm.forward(&hidden_states)?;
+        let hidden_states = hidden_states.broadcast_matmul(shared_embeddings)?;
+        Ok(hidden_states)
+    }
+}
+
+pub struct XLMRobertaForMaskedLM {
+    roberta: XLMRobertaModel,
+    lm_head: XLMRobertaLMHead,
+}
+
+impl XLMRobertaForMaskedLM {
+    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let roberta = XLMRobertaModel::new(cfg, vb.pp("roberta"))?;
+        let lm_head = XLMRobertaLMHead::new(cfg, vb.pp("lm_head"))?;
+        Ok(Self { roberta, lm_head })
+    }
+
+    pub fn forward(
+        &self,
+        input_ids: &Tensor,
+        attention_mask: &Tensor,
+        token_type_ids: &Tensor,
+        past_key_value: Option<(&Tensor, &Tensor)>,
+        encoder_hidden_states: Option<&Tensor>,
+        encoder_attention_mask: Option<&Tensor>,
+    ) -> Result<Tensor> {
+        let hidden_states = self.roberta.forward(
+            input_ids,
+            attention_mask,
+            token_type_ids,
+            past_key_value,
+            encoder_hidden_states,
+            encoder_attention_mask,
+        )?;
+        let lm_logits = self.lm_head.forward(
+            &hidden_states,
+            &self
+                .roberta
+                .embeddings
+                .word_embeddings
+                .embeddings()
+                .t()?
+                .unsqueeze(0)?,
+        )?;
+        Ok(lm_logits)
+    }
+}
+
+struct XLMRobertaClassificationHead {
+    dense: Linear,
+    out_proj: Linear,
+}
+
+impl XLMRobertaClassificationHead {
+    fn new(num_labels: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let dense = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("dense"))?;
+        let out_proj = linear(cfg.hidden_size, num_labels, vb.pp("out_proj"))?;
+        Ok(Self { dense, out_proj })
+    }
+
+    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
+        let cls_states = hidden_states.get_on_dim(1, 0)?.contiguous()?;
+        let hidden_states = self.dense.forward(&cls_states)?;
+        let hidden_states = candle_nn::Activation::GeluPytorchTanh.forward(&hidden_states)?;
+        let hidden_states = self.out_proj.forward(&hidden_states)?;
+        Ok(hidden_states)
+    }
+}
+
+pub struct XLMRobertaForSequenceClassification {
+    roberta: XLMRobertaModel,
+    classifier: XLMRobertaClassificationHead,
+}
+
+impl XLMRobertaForSequenceClassification {
+    pub fn new(num_labels: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let roberta = XLMRobertaModel::new(cfg, vb.pp("roberta"))?;
+        let classifier = XLMRobertaClassificationHead::new(num_labels, cfg, vb.pp("classifier"))?;
+        Ok(Self {
+            roberta,
+            classifier,
+        })
+    }
+
+    pub fn forward(
+        &self,
+        input_ids: &Tensor,
+        attention_mask: &Tensor,
+        token_type_ids: &Tensor,
+    ) -> Result<Tensor> {
+        let hidden_states =
+            self.roberta
+                .forward(input_ids, attention_mask, token_type_ids, None, None, None)?;
+        self.classifier.forward(&hidden_states)
+    }
+}
+
+fn prepare_4d_attention_mask(
+    mask: &Tensor,
+    dtype: DType,
+    tgt_len: Option<usize>,
+) -> Result<Tensor> {
+    let bsz = mask.dim(0)?;
+    let src_len = mask.dim(1)?;
+    let tgt_len = tgt_len.unwrap_or(src_len);
+
+    let expanded_mask = mask
+        .unsqueeze(1)?
+        .unsqueeze(2)?
+        .expand((bsz, 1, tgt_len, src_len))?
+        .to_dtype(dtype)?;
+
+    let inverted_mask = (1.0 - expanded_mask)?;
+
+    (inverted_mask * get_dtype_min_val(dtype))?.to_dtype(dtype)
+}
+
+fn get_dtype_min_val(dtype: DType) -> f64 {
+    match dtype {
+        DType::F32 => f32::MIN as f64,
+        DType::F64 => f64::MIN,
+        _ => panic!("Unsupported data type"),
+    }
+}