author    Akshay Ballal <61191840+akshayballal95@users.noreply.github.com>  2024-12-30 11:16:57 +0100
committer GitHub <noreply@github.com>  2024-12-30 11:16:57 +0100
commit    91f1f019b13386f4df3e9b2826c982d10bcc497e (patch)
tree      1acf1a2921592bbd62ac43ea1321376164b8724b
parent    cd639131f04990c16bfc498ea347cb9df3d2374f (diff)
Added XLMRobertaModel for Reranking (#2686)
* add xlm-roberta-base

* Add task enum for fill-mask and reranker in xlm-roberta example; update README and fix attention mask dimensions

  - Introduced a new `Task` enum to replace string task identifiers in the xlm-roberta example.
  - Updated the logic in `main.rs` to handle tasks using the new enum.
  - Enhanced README with example output for fill-mask task.
  - Fixed dimension retrieval in `prepare_4d_attention_mask` function for better clarity and safety.

* Clippy fix.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
-rw-r--r--  candle-examples/examples/xlm-roberta/Readme.md    30
-rw-r--r--  candle-examples/examples/xlm-roberta/main.rs     277
-rw-r--r--  candle-transformers/src/models/mod.rs               1
-rw-r--r--  candle-transformers/src/models/xlm_roberta.rs     545
4 files changed, 853 insertions, 0 deletions
diff --git a/candle-examples/examples/xlm-roberta/Readme.md b/candle-examples/examples/xlm-roberta/Readme.md
new file mode 100644
index 00000000..496b14e3
--- /dev/null
+++ b/candle-examples/examples/xlm-roberta/Readme.md
@@ -0,0 +1,30 @@
+# candle-xlm-roberta
+
+This example demonstrates how to use XLM-RoBERTa models in Candle; these models are especially well known for their use in reranking. The example supports a `fill-mask` task, which predicts a word for a masked token, and a `reranker` task, which reranks a list of documents for a given query.
+
+## Usage
+
+Fill Mask:
+```bash
+cargo run --example xlm-roberta --release -- --task fill-mask --model xlm-roberta-base
+```
+```markdown
+Sentence: 1 : Hello I'm a fashion model.
+Sentence: 2 : I'm a little boy.
+Sentence: 3 : I'm living in berlin.
+```
+
+Reranker:
+```bash
+cargo run --example xlm-roberta --release -- --task reranker --model bge-reranker-base
+```
+```markdown
+Ranking Results:
+--------------------------------------------------------------------------------
+> Rank #4 | Score: 0.0001 | South Korea is a country in East Asia.
+> Rank #5 | Score: 0.0000 | There are forests in the mountains.
+> Rank #2 | Score: 0.7314 | Pandas look like bears.
+> Rank #3 | Score: 0.6948 | There are some animals with black and white fur.
+> Rank #1 | Score: 0.9990 | The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China.
+--------------------------------------------------------------------------------
+```
diff --git a/candle-examples/examples/xlm-roberta/main.rs b/candle-examples/examples/xlm-roberta/main.rs
new file mode 100644
index 00000000..47ab44b0
--- /dev/null
+++ b/candle-examples/examples/xlm-roberta/main.rs
@@ -0,0 +1,277 @@
+use std::path::PathBuf;
+
+use anyhow::{Error as E, Result};
+use candle::{Device, Tensor};
+use candle_nn::VarBuilder;
+use candle_transformers::models::xlm_roberta::{
+ Config, XLMRobertaForMaskedLM, XLMRobertaForSequenceClassification,
+};
+use clap::{Parser, ValueEnum};
+use hf_hub::{api::sync::Api, Repo, RepoType};
+use tokenizers::{PaddingParams, Tokenizer};
+
+#[derive(Debug, Clone, ValueEnum)]
+enum Model {
+ BgeRerankerBase,
+ BgeRerankerLarge,
+ BgeRerankerBaseV2,
+ XLMRobertaBase,
+ XLMRobertaLarge,
+}
+
+#[derive(Debug, Clone, ValueEnum)]
+enum Task {
+ FillMask,
+ Reranker,
+}
+
+#[derive(Parser, Debug)]
+#[command(author, version, about, long_about = None)]
+struct Args {
+ /// Run on CPU rather than on GPU.
+ #[arg(long)]
+ cpu: bool,
+
+ /// Enable tracing (generates a trace-timestamp.json file).
+ #[arg(long)]
+ tracing: bool,
+
+    /// The model to use; check out the available models at https://huggingface.co/models?library=sentence-transformers&sort=trending
+ #[arg(long)]
+ model_id: Option<String>,
+
+ #[arg(long, default_value = "main")]
+ revision: String,
+
+ #[arg(long, default_value = "bge-reranker-base")]
+ model: Model,
+
+ #[arg(long, default_value = "reranker")]
+ task: Task,
+
+    /// Path to the tokenizer file.
+ #[arg(long)]
+ tokenizer_file: Option<String>,
+
+    /// Path to the weight files.
+ #[arg(long)]
+ weight_files: Option<String>,
+
+    /// Path to the config file.
+ #[arg(long)]
+ config_file: Option<String>,
+
+ /// When set, compute embeddings for this prompt.
+ #[arg(long)]
+ prompt: Option<String>,
+}
+
+fn main() -> Result<()> {
+ let args = Args::parse();
+ let api = Api::new()?;
+ let model_id = match &args.model_id {
+ Some(model_id) => model_id.to_string(),
+ None => match args.task {
+ Task::FillMask => match args.model {
+ Model::XLMRobertaBase => "FacebookAI/xlm-roberta-base".to_string(),
+ Model::XLMRobertaLarge => "FacebookAI/xlm-roberta-large".to_string(),
+ _ => anyhow::bail!("BGE models are not supported for fill-mask task"),
+ },
+ Task::Reranker => match args.model {
+ Model::BgeRerankerBase => "BAAI/bge-reranker-base".to_string(),
+ Model::BgeRerankerLarge => "BAAI/bge-reranker-large".to_string(),
+ Model::BgeRerankerBaseV2 => "BAAI/bge-reranker-base-v2-m3".to_string(),
+ _ => anyhow::bail!("XLM-RoBERTa models are not supported for reranker task"),
+ },
+ },
+ };
+ let repo = api.repo(Repo::with_revision(
+ model_id,
+ RepoType::Model,
+ args.revision,
+ ));
+
+ let tokenizer_filename = match args.tokenizer_file {
+ Some(file) => std::path::PathBuf::from(file),
+ None => repo.get("tokenizer.json")?,
+ };
+
+ let config_filename = match args.config_file {
+ Some(file) => std::path::PathBuf::from(file),
+ None => repo.get("config.json")?,
+ };
+
+ let weights_filename = match args.weight_files {
+ Some(files) => PathBuf::from(files),
+ None => match repo.get("model.safetensors") {
+ Ok(safetensors) => safetensors,
+ Err(_) => match repo.get("pytorch_model.bin") {
+ Ok(pytorch_model) => pytorch_model,
+ Err(e) => {
+ return Err(anyhow::Error::msg(format!("Model weights not found. The weights should either be a `model.safetensors` or `pytorch_model.bin` file. Error: {}", e)));
+ }
+ },
+ },
+ };
+
+ let config = std::fs::read_to_string(config_filename)?;
+ let config: Config = serde_json::from_str(&config)?;
+ let mut tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
+
+ let device = candle_examples::device(args.cpu)?;
+
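+    // Weights are loaded in F16; `model.safetensors` files are memory-mapped, while `pytorch_model.bin` is read via `from_pth`.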
+ let vb = if weights_filename.ends_with("model.safetensors") {
+ unsafe {
+ VarBuilder::from_mmaped_safetensors(&[weights_filename], candle::DType::F16, &device)
+ .unwrap()
+ }
+ } else {
+ println!("Loading weights from pytorch_model.bin");
+ VarBuilder::from_pth(&weights_filename, candle::DType::F16, &device).unwrap()
+ };
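+    // Pad every batch to its longest sequence with the model's pad token id; truncation is disabled.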
+ tokenizer
+ .with_padding(Some(PaddingParams {
+ strategy: tokenizers::PaddingStrategy::BatchLongest,
+ pad_id: config.pad_token_id,
+ ..Default::default()
+ }))
+ .with_truncation(None)
+ .map_err(E::msg)?;
+
+ match args.task {
+ Task::FillMask => {
+ let prompt = vec![
+ "Hello I'm a <mask> model.".to_string(),
+ "I'm a <mask> boy.".to_string(),
+ "I'm <mask> in berlin.".to_string(),
+ ];
+ let model = XLMRobertaForMaskedLM::new(&config, vb)?;
+
+ let input_ids = tokenize_batch(&tokenizer, TokenizeInput::Single(&prompt), &device)?;
+ let attention_mask =
+ get_attention_mask(&tokenizer, TokenizeInput::Single(&prompt), &device)?;
+
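+            // XLM-RoBERTa uses a single token type, so the token type ids are all zeros.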
+ let token_type_ids = Tensor::zeros(input_ids.dims(), input_ids.dtype(), &device)?;
+
+ let output = model
+ .forward(
+ &input_ids,
+ &attention_mask,
+ &token_type_ids,
+ None,
+ None,
+ None,
+ )?
+ .to_dtype(candle::DType::F32)?;
+
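+            // Take the most likely token id at every position (argmax over the vocabulary dimension)
+            // and decode each sequence, which fills the <mask> slots with the model's predictions.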
+ let max_outs = output.argmax(2)?;
+
+ let max_out = max_outs.to_vec2::<u32>()?;
+ let max_out_refs: Vec<&[u32]> = max_out.iter().map(|v| v.as_slice()).collect();
+ let decoded = tokenizer.decode_batch(&max_out_refs, true).unwrap();
+ for (i, sentence) in decoded.iter().enumerate() {
+ println!("Sentence: {} : {}", i + 1, sentence);
+ }
+ }
+ Task::Reranker => {
+ let query = "what is panda?".to_string();
+
+ let documents = ["South Korea is a country in East Asia.".to_string(),
+ "There are forests in the mountains.".to_string(),
+ "Pandas look like bears.".to_string(),
+ "There are some animals with black and white fur.".to_string(),
+ "The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China.".to_string()];
+
+ // create pairs of query and documents
+ let pairs = documents
+ .iter()
+ .map(|doc| (query.clone(), doc.clone()))
+ .collect::<Vec<_>>();
+ let input_ids = tokenize_batch(&tokenizer, TokenizeInput::Pairs(&pairs), &device)?;
+ let attention_mask =
+ get_attention_mask(&tokenizer, TokenizeInput::Pairs(&pairs), &device)?;
+ let token_type_ids = Tensor::zeros(input_ids.dims(), input_ids.dtype(), &device)?;
+
+ let model = XLMRobertaForSequenceClassification::new(1, &config, vb)?;
+
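+            // The classifier emits one logit per (query, document) pair; sigmoid maps it to a relevance
+            // score, and arg_sort_last_dim(false) orders the documents from highest to lowest score.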
+ let output = model.forward(&input_ids, &attention_mask, &token_type_ids)?;
+ let output = candle_nn::ops::sigmoid(&output)?.t().unwrap();
+ let ranks = output
+ .arg_sort_last_dim(false)?
+ .to_vec2::<u32>()?
+ .into_iter()
+ .flatten()
+ .collect::<Vec<_>>();
+ println!("\nRanking Results:");
+ println!("{:-<80}", "");
+ documents.iter().enumerate().for_each(|(idx, doc)| {
+ let rank = ranks.iter().position(|&r| r == idx as u32).unwrap();
+ let score = output
+ .get_on_dim(1, idx)
+ .unwrap()
+ .to_dtype(candle::DType::F32)
+ .unwrap()
+ .to_vec1::<f32>()
+ .unwrap();
+ println!("Rank #{:<2} | Score: {:.4} | {}", rank + 1, score[0], doc);
+ });
+ println!("{:-<80}", "");
+ }
+ }
+ Ok(())
+}
+
+#[derive(Debug)]
+pub enum TokenizeInput<'a> {
+ Single(&'a [String]),
+ Pairs(&'a [(String, String)]),
+}
+
+pub fn tokenize_batch(
+ tokenizer: &Tokenizer,
+ input: TokenizeInput,
+ device: &Device,
+) -> anyhow::Result<Tensor> {
+ let tokens = match input {
+ TokenizeInput::Single(text_batch) => tokenizer
+ .encode_batch(text_batch.to_vec(), true)
+ .map_err(E::msg)?,
+ TokenizeInput::Pairs(pairs) => tokenizer
+ .encode_batch(pairs.to_vec(), true)
+ .map_err(E::msg)?,
+ };
+
+ let token_ids = tokens
+ .iter()
+ .map(|tokens| {
+ let tokens = tokens.get_ids().to_vec();
+ Tensor::new(tokens.as_slice(), device)
+ })
+ .collect::<candle::Result<Vec<_>>>()?;
+
+ Ok(Tensor::stack(&token_ids, 0)?)
+}
+
+pub fn get_attention_mask(
+ tokenizer: &Tokenizer,
+ input: TokenizeInput,
+ device: &Device,
+) -> anyhow::Result<Tensor> {
+ let tokens = match input {
+ TokenizeInput::Single(text_batch) => tokenizer
+ .encode_batch(text_batch.to_vec(), true)
+ .map_err(E::msg)?,
+ TokenizeInput::Pairs(pairs) => tokenizer
+ .encode_batch(pairs.to_vec(), true)
+ .map_err(E::msg)?,
+ };
+
+ let attention_mask = tokens
+ .iter()
+ .map(|tokens| {
+ let tokens = tokens.get_attention_mask().to_vec();
+ Tensor::new(tokens.as_slice(), device)
+ })
+ .collect::<candle::Result<Vec<_>>>()?;
+ Ok(Tensor::stack(&attention_mask, 0)?)
+}
diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs
index be1f15c4..5f566991 100644
--- a/candle-transformers/src/models/mod.rs
+++ b/candle-transformers/src/models/mod.rs
@@ -109,4 +109,5 @@ pub mod vit;
pub mod whisper;
pub mod with_tracing;
pub mod wuerstchen;
+pub mod xlm_roberta;
pub mod yi;
diff --git a/candle-transformers/src/models/xlm_roberta.rs b/candle-transformers/src/models/xlm_roberta.rs
new file mode 100644
index 00000000..96e763e1
--- /dev/null
+++ b/candle-transformers/src/models/xlm_roberta.rs
@@ -0,0 +1,545 @@
+use crate::models::with_tracing::{linear, Linear};
+use candle::{DType, Module, Result, Tensor};
+use candle_nn::{
+ embedding, layer_norm, ops::softmax_last_dim, Activation, Embedding, LayerNorm, VarBuilder,
+};
+
+#[derive(Debug, Clone, serde::Deserialize)]
+pub struct Config {
+ pub hidden_size: usize,
+ pub layer_norm_eps: f64,
+ pub attention_probs_dropout_prob: f32,
+ pub hidden_dropout_prob: f32,
+ pub num_attention_heads: usize,
+ pub position_embedding_type: String,
+ pub intermediate_size: usize,
+ pub hidden_act: Activation,
+ pub num_hidden_layers: usize,
+ pub vocab_size: usize,
+ pub max_position_embeddings: usize,
+ pub type_vocab_size: usize,
+ pub pad_token_id: u32,
+}
+
+struct XLMRobertaEmbeddings {
+ word_embeddings: Embedding,
+ position_embeddings: Option<Embedding>,
+ token_type_embeddings: Embedding,
+ layer_norm: LayerNorm,
+ padding_idx: u32,
+ span: tracing::Span,
+}
+
+impl XLMRobertaEmbeddings {
+ fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
+ let word_embeddings = embedding(
+ config.vocab_size,
+ config.hidden_size,
+ vb.pp("word_embeddings"),
+ )?;
+ let position_embeddings = embedding(
+ config.max_position_embeddings,
+ config.hidden_size,
+ vb.pp("position_embeddings"),
+ )?;
+ let token_type_embeddings = embedding(
+ config.type_vocab_size,
+ config.hidden_size,
+ vb.pp("token_type_embeddings"),
+ )?;
+ let layer_norm = layer_norm(
+ config.hidden_size,
+ config.layer_norm_eps,
+ vb.pp("LayerNorm"),
+ )?;
+ Ok(Self {
+ word_embeddings,
+ position_embeddings: Some(position_embeddings),
+ token_type_embeddings,
+ layer_norm,
+ padding_idx: config.pad_token_id,
+ span: tracing::span!(tracing::Level::TRACE, "embeddings"),
+ })
+ }
+
+ fn forward(&self, input_ids: &Tensor, token_type_ids: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ let (_bsize, _) = input_ids.dims2()?;
+ let input_embeddings = self.word_embeddings.forward(input_ids)?;
+ let token_type_embeddings = self.token_type_embeddings.forward(token_type_ids)?;
+ let mut embeddings = (&input_embeddings + token_type_embeddings)?;
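+        // Position ids follow the RoBERTa convention: padding tokens keep padding_idx, while real
+        // tokens get padding_idx plus their 1-based position among the non-padding tokens.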
+ if let Some(position_embeddings) = &self.position_embeddings {
+ let mask = input_ids
+ .ne(self.padding_idx)?
+ .to_dtype(input_embeddings.dtype())?;
+ let cumsum = mask.cumsum(1)?;
+ let position_ids = (cumsum * mask)?
+ .broadcast_add(
+ &Tensor::try_from(self.padding_idx)?
+ .to_dtype(input_embeddings.dtype())?
+ .to_device(input_embeddings.device())?,
+ )?
+ .to_dtype(candle::DType::U32)?;
+ embeddings = embeddings.broadcast_add(&position_embeddings.forward(&position_ids)?)?;
+ }
+ let embeddings = self.layer_norm.forward(&embeddings)?;
+ Ok(embeddings)
+ }
+}
+
+struct XLMRobertaSelfAttention {
+ num_attention_heads: usize,
+ attention_head_size: usize,
+ all_head_size: usize,
+ query: Linear,
+ key: Linear,
+ value: Linear,
+}
+
+impl XLMRobertaSelfAttention {
+ fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let attention_head_size = cfg.hidden_size / cfg.num_attention_heads;
+ let all_head_size = cfg.num_attention_heads * attention_head_size;
+ Ok(Self {
+ num_attention_heads: cfg.num_attention_heads,
+ attention_head_size,
+ all_head_size,
+ query: linear(cfg.hidden_size, all_head_size, vb.pp("query"))?,
+ key: linear(cfg.hidden_size, all_head_size, vb.pp("key"))?,
+ value: linear(cfg.hidden_size, all_head_size, vb.pp("value"))?,
+ })
+ }
+
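+    // Splits (batch, seq, all_head_size) into heads and permutes to (batch, heads, seq, head_dim).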
+ fn transpose_for_scores(&self, x: &Tensor) -> Result<Tensor> {
+ let mut new_x_shape = x.dims().to_vec();
+ new_x_shape[2] = self.num_attention_heads;
+ new_x_shape.push(self.attention_head_size);
+ let x = x.reshape(new_x_shape)?;
+ x.permute((0, 2, 1, 3))?.contiguous()
+ }
+
+ fn forward(
+ &self,
+ hidden_states: &Tensor,
+ encoder_hidden_states: Option<&Tensor>,
+ attention_mask: &Tensor,
+ past_key_value: Option<(&Tensor, &Tensor)>,
+ encoder_attention_mask: Option<&Tensor>,
+ ) -> Result<Tensor> {
+ let mixed_query_layer = self.query.forward(hidden_states)?;
+ let is_cross_attention = encoder_hidden_states.is_some();
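+        // Select keys/values depending on the setup: cached cross-attention K/V, freshly projected
+        // encoder states (cross-attention), a growing self-attention cache, or plain projections of
+        // the current hidden states.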
+ let (key_layer, value_layer, attention_mask) = if is_cross_attention
+ && past_key_value.is_some()
+ {
+ let key_layer = past_key_value.unwrap().0.clone();
+ let value_layer = past_key_value.unwrap().1.clone();
+ let attention_mask = encoder_attention_mask.unwrap().clone();
+ (key_layer, value_layer, Some(attention_mask))
+ } else if is_cross_attention {
+ let key_layer =
+ self.transpose_for_scores(&self.key.forward(encoder_hidden_states.unwrap())?)?;
+ let value_layer =
+ self.transpose_for_scores(&self.value.forward(encoder_hidden_states.unwrap())?)?;
+ let attention_mask = encoder_attention_mask.unwrap();
+ (key_layer, value_layer, Some(attention_mask.clone()))
+ } else if past_key_value.is_some() {
+ let mut key_layer = self.transpose_for_scores(&self.key.forward(hidden_states)?)?;
+ let mut value_layer = self.transpose_for_scores(&self.value.forward(hidden_states)?)?;
+ key_layer = Tensor::cat(
+ &[
+                    past_key_value.as_ref().unwrap().0.clone(),
+ key_layer,
+ ],
+ 2,
+ )?;
+ value_layer = Tensor::cat(
+ &[past_key_value.as_ref().unwrap().1.clone(), value_layer],
+ 2,
+ )?;
+ (key_layer, value_layer, Some(attention_mask.clone()))
+ } else {
+ let key_layer = self.transpose_for_scores(&self.key.forward(hidden_states)?)?;
+ let value_layer = self.transpose_for_scores(&self.value.forward(hidden_states)?)?;
+ (key_layer, value_layer, Some(attention_mask.clone()))
+ };
+
+ let query_layer = self.transpose_for_scores(&mixed_query_layer)?;
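+        // Scaled dot-product attention: Q·Kᵀ / sqrt(head_dim), plus the additive attention mask,
+        // then softmax over the key dimension.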
+ let mut attention_scores = query_layer.matmul(&key_layer.transpose(2, 3)?)?;
+ let scale = 1f64 / f64::sqrt(self.attention_head_size as f64);
+
+ attention_scores = (attention_scores * scale)?;
+ attention_scores = match attention_mask {
+ None => attention_scores,
+ Some(mask) => {
+ attention_scores.broadcast_add(&mask.to_dtype(attention_scores.dtype())?)?
+ }
+ };
+ let attention_probs = softmax_last_dim(&attention_scores)?;
+
+ let context_layer = attention_probs
+ .matmul(&value_layer)?
+ .permute((0, 2, 1, 3))?
+ .contiguous()?;
+ let mut new_context_layer_shape =
+ context_layer.dims()[..context_layer.dims().len() - 2].to_vec();
+ new_context_layer_shape.push(self.all_head_size);
+ let context_layer = context_layer.reshape(new_context_layer_shape)?;
+
+ Ok(context_layer)
+ }
+}
+
+struct XLMRobertaSelfOutput {
+ dense: Linear,
+ layernorm: LayerNorm,
+}
+
+impl XLMRobertaSelfOutput {
+ fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let dense = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("dense"))?;
+ let layernorm =
+ candle_nn::layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("LayerNorm"))?;
+ Ok(Self { dense, layernorm })
+ }
+
+ fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
+ let hidden_states = self.dense.forward(hidden_states)?;
+ let hidden_states = self.layernorm.forward(&(hidden_states + input_tensor)?)?;
+ Ok(hidden_states)
+ }
+}
+
+struct XLMRobertaAttention {
+ output: XLMRobertaSelfOutput,
+ self_attention: XLMRobertaSelfAttention,
+}
+
+impl XLMRobertaAttention {
+ fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let output = XLMRobertaSelfOutput::new(cfg, vb.pp("output"))?;
+ let self_attention = XLMRobertaSelfAttention::new(cfg, vb.pp("self"))?;
+ Ok(Self {
+ output,
+ self_attention,
+ })
+ }
+
+ fn forward(
+ &self,
+ hidden_states: &Tensor,
+ attention_mask: &Tensor,
+ encoder_hidden_states: Option<&Tensor>,
+ encoder_attention_mask: Option<&Tensor>,
+ past_key_value: Option<(&Tensor, &Tensor)>,
+ ) -> Result<(Tensor, Tensor)> {
+ let self_outputs = self.self_attention.forward(
+ hidden_states,
+ encoder_hidden_states,
+ attention_mask,
+ past_key_value,
+ encoder_attention_mask,
+ )?;
+ let attention_output = self.output.forward(&self_outputs, hidden_states)?;
+ Ok((attention_output, self_outputs))
+ }
+}
+
+struct XLMRobertaOutput {
+ dense: Linear,
+ layernorm: LayerNorm,
+}
+
+impl XLMRobertaOutput {
+ fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let dense = linear(cfg.intermediate_size, cfg.hidden_size, vb.pp("dense"))?;
+ let layernorm =
+ candle_nn::layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("LayerNorm"))?;
+ Ok(Self { dense, layernorm })
+ }
+
+ fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
+ let hidden_states = self.dense.forward(hidden_states)?;
+ let hidden_states = self.layernorm.forward(&(hidden_states + input_tensor)?)?;
+ Ok(hidden_states)
+ }
+}
+
+struct XLMRobertaIntermediate {
+ dense: Linear,
+ intermediate_act_fn: Activation,
+}
+
+impl XLMRobertaIntermediate {
+ fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let dense = linear(cfg.hidden_size, cfg.intermediate_size, vb.pp("dense"))?;
+ let intermediate_act_fn = cfg.hidden_act;
+ Ok(Self {
+ dense,
+ intermediate_act_fn,
+ })
+ }
+
+ fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
+ let hidden_states = self.dense.forward(hidden_states)?;
+ let hidden_states = self.intermediate_act_fn.forward(&hidden_states)?;
+ Ok(hidden_states)
+ }
+}
+
+struct XLMRobertaLayer {
+ attention: XLMRobertaAttention,
+ intermediate: XLMRobertaIntermediate,
+ output: XLMRobertaOutput,
+}
+
+impl XLMRobertaLayer {
+ fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let attention = XLMRobertaAttention::new(cfg, vb.pp("attention"))?;
+ let intermediate = XLMRobertaIntermediate::new(cfg, vb.pp("intermediate"))?;
+ let output = XLMRobertaOutput::new(cfg, vb.pp("output"))?;
+ Ok(Self {
+ attention,
+ intermediate,
+ output,
+ })
+ }
+
+ fn forward(
+ &self,
+ hidden_states: &Tensor,
+ attention_mask: &Tensor,
+ encoder_hidden_states: Option<&Tensor>,
+ encoder_attention_mask: Option<&Tensor>,
+ past_key_value: Option<(&Tensor, &Tensor)>,
+ ) -> Result<(Tensor, Tensor)> {
+ let self_attention_outputs = self.attention.forward(
+ hidden_states,
+ attention_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ past_key_value,
+ )?;
+ let attention_output = self_attention_outputs.0;
+ let outputs = self_attention_outputs.1;
+ let intermediate_output = self.intermediate.forward(&attention_output)?;
+ let layer_output = self
+ .output
+ .forward(&intermediate_output, &attention_output)?;
+ Ok((layer_output, outputs))
+ }
+}
+
+struct XLMRobertaEncoder {
+ layers: Vec<XLMRobertaLayer>,
+}
+
+impl XLMRobertaEncoder {
+ fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let layers = (0..cfg.num_hidden_layers)
+ .map(|i| XLMRobertaLayer::new(cfg, vb.pp(format!("layer.{}", i))))
+ .collect::<Result<Vec<_>>>()?;
+ Ok(Self { layers })
+ }
+
+ fn forward(
+ &self,
+ hidden_states: &Tensor,
+ attention_mask: &Tensor,
+ encoder_hidden_states: Option<&Tensor>,
+ encoder_attention_mask: Option<&Tensor>,
+ past_key_value: Option<(&Tensor, &Tensor)>,
+ ) -> Result<Tensor> {
+ let mut hidden_states = hidden_states.clone();
+ for layer_module in self.layers.iter() {
+ let layer_outputs = layer_module.forward(
+ &hidden_states,
+ attention_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ past_key_value,
+ )?;
+ hidden_states = layer_outputs.0;
+ }
+ Ok(hidden_states)
+ }
+}
+
+pub struct XLMRobertaModel {
+ encoder: XLMRobertaEncoder,
+ embeddings: XLMRobertaEmbeddings,
+}
+
+impl XLMRobertaModel {
+ pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let encoder = XLMRobertaEncoder::new(cfg, vb.pp("encoder"))?;
+ let embeddings = XLMRobertaEmbeddings::load(vb.pp("embeddings"), cfg)?;
+ Ok(Self {
+ encoder,
+ embeddings,
+ })
+ }
+
+ pub fn forward(
+ &self,
+ input_ids: &Tensor,
+ attention_mask: &Tensor,
+ token_type_ids: &Tensor,
+ past_key_value: Option<(&Tensor, &Tensor)>,
+ encoder_hidden_states: Option<&Tensor>,
+ encoder_attention_mask: Option<&Tensor>,
+ ) -> Result<Tensor> {
+ let hidden_states = self.embeddings.forward(input_ids, token_type_ids)?;
+ let attention_mask = prepare_4d_attention_mask(attention_mask, DType::F32, None)?
+ .to_device(hidden_states.device())?;
+ let hidden_states = self.encoder.forward(
+ &hidden_states,
+ &attention_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ past_key_value,
+ )?;
+ Ok(hidden_states)
+ }
+}
+
+struct XLMRobertaLMHead {
+ dense: Linear,
+ layer_norm: LayerNorm,
+}
+
+impl XLMRobertaLMHead {
+ fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let dense = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("dense"))?;
+ let layer_norm =
+ candle_nn::layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("layer_norm"))?;
+ Ok(Self { dense, layer_norm })
+ }
+
+ fn forward(&self, hidden_states: &Tensor, shared_embeddings: &Tensor) -> Result<Tensor> {
+ let hidden_states = self.dense.forward(hidden_states)?;
+ let hidden_states = candle_nn::Activation::Gelu.forward(&hidden_states)?;
+ let hidden_states = self.layer_norm.forward(&hidden_states)?;
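+        // The caller passes the transposed word-embedding matrix, so the output projection is tied
+        // to the input embeddings.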
+ let hidden_states = hidden_states.broadcast_matmul(shared_embeddings)?;
+ Ok(hidden_states)
+ }
+}
+
+pub struct XLMRobertaForMaskedLM {
+ roberta: XLMRobertaModel,
+ lm_head: XLMRobertaLMHead,
+}
+
+impl XLMRobertaForMaskedLM {
+ pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let roberta = XLMRobertaModel::new(cfg, vb.pp("roberta"))?;
+ let lm_head = XLMRobertaLMHead::new(cfg, vb.pp("lm_head"))?;
+ Ok(Self { roberta, lm_head })
+ }
+
+ pub fn forward(
+ &self,
+ input_ids: &Tensor,
+ attention_mask: &Tensor,
+ token_type_ids: &Tensor,
+ past_key_value: Option<(&Tensor, &Tensor)>,
+ encoder_hidden_states: Option<&Tensor>,
+ encoder_attention_mask: Option<&Tensor>,
+ ) -> Result<Tensor> {
+ let hidden_states = self.roberta.forward(
+ input_ids,
+ attention_mask,
+ token_type_ids,
+ past_key_value,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ )?;
+ let lm_logits = self.lm_head.forward(
+ &hidden_states,
+ &self
+ .roberta
+ .embeddings
+ .word_embeddings
+ .embeddings()
+ .t()?
+ .unsqueeze(0)?,
+ )?;
+ Ok(lm_logits)
+ }
+}
+
+struct XLMRobertaClassificationHead {
+ dense: Linear,
+ out_proj: Linear,
+}
+
+impl XLMRobertaClassificationHead {
+ fn new(num_labels: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let dense = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("dense"))?;
+ let out_proj = linear(cfg.hidden_size, num_labels, vb.pp("out_proj"))?;
+ Ok(Self { dense, out_proj })
+ }
+
+ fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
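+        // Pool by taking the hidden state of the first token (<s>) as the sequence representation.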
+ let cls_states = hidden_states.get_on_dim(1, 0)?.contiguous()?;
+ let hidden_states = self.dense.forward(&cls_states)?;
+ let hidden_states = candle_nn::Activation::GeluPytorchTanh.forward(&hidden_states)?;
+ let hidden_states = self.out_proj.forward(&hidden_states)?;
+ Ok(hidden_states)
+ }
+}
+
+pub struct XLMRobertaForSequenceClassification {
+ roberta: XLMRobertaModel,
+ classifier: XLMRobertaClassificationHead,
+}
+
+impl XLMRobertaForSequenceClassification {
+ pub fn new(num_labels: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let roberta = XLMRobertaModel::new(cfg, vb.pp("roberta"))?;
+ let classifier = XLMRobertaClassificationHead::new(num_labels, cfg, vb.pp("classifier"))?;
+ Ok(Self {
+ roberta,
+ classifier,
+ })
+ }
+
+ pub fn forward(
+ &self,
+ input_ids: &Tensor,
+ attention_mask: &Tensor,
+ token_type_ids: &Tensor,
+ ) -> Result<Tensor> {
+ let hidden_states =
+ self.roberta
+ .forward(input_ids, attention_mask, token_type_ids, None, None, None)?;
+ self.classifier.forward(&hidden_states)
+ }
+}
+
+fn prepare_4d_attention_mask(
+ mask: &Tensor,
+ dtype: DType,
+ tgt_len: Option<usize>,
+) -> Result<Tensor> {
+ let bsz = mask.dim(0)?;
+ let src_len = mask.dim(1)?;
+ let tgt_len = tgt_len.unwrap_or(src_len);
+
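+    // Expand the (batch, seq) padding mask to (batch, 1, tgt_len, src_len) and invert it into an
+    // additive mask: 0 where attention is allowed, dtype-min where it is masked.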
+ let expanded_mask = mask
+ .unsqueeze(1)?
+ .unsqueeze(2)?
+ .expand((bsz, 1, tgt_len, src_len))?
+ .to_dtype(dtype)?;
+
+ let inverted_mask = (1.0 - expanded_mask)?;
+
+ (inverted_mask * get_dtype_min_val(dtype))?.to_dtype(dtype)
+}
+
+fn get_dtype_min_val(dtype: DType) -> f64 {
+ match dtype {
+ DType::F32 => f32::MIN as f64,
+ DType::F64 => f64::MIN,
+ _ => panic!("Unsupported data type"),
+ }
+}