author    Jani Monoses <jani.monoses@gmail.com>    2025-01-13 09:39:27 +0200
committer GitHub <noreply@github.com>              2025-01-13 08:39:27 +0100
commit    461e8c1685e003bdddfd1e7d1aa5092786ca9df5 (patch)
tree      ab847343b9305176db0ce630246330a7a04e84bd
parent    2344c4e4b89dcb57c021459140c3914faa4df603 (diff)
ModernBERT model (#2713)
* layer_norm_no_bias
* Modernbert model.
* Format + cleanup error.

Co-authored-by: laurent <laurent.mazare@gmail.com>
-rw-r--r--  candle-examples/examples/modernbert/README.md    12
-rw-r--r--  candle-examples/examples/modernbert/main.rs      180
-rw-r--r--  candle-nn/src/layer_norm.rs                        9
-rw-r--r--  candle-nn/src/lib.rs                               4
-rw-r--r--  candle-transformers/src/models/mod.rs              1
-rw-r--r--  candle-transformers/src/models/modernbert.rs     407
6 files changed, 612 insertions(+), 1 deletion(-)
diff --git a/candle-examples/examples/modernbert/README.md b/candle-examples/examples/modernbert/README.md
new file mode 100644
index 00000000..4eba2d7d
--- /dev/null
+++ b/candle-examples/examples/modernbert/README.md
@@ -0,0 +1,12 @@
+# candle-modernbert
+
+ModernBERT is a bidirectional encoder-only language model. In this example it is used for the fill-mask task.
+
+## Usage
+
+```bash
+cargo run --example modernbert --release -- --model modern-bert-large --prompt 'The capital of France is [MASK].'
+```
+```markdown
+Sentence: 1 : The capital of France is Paris.
+```
diff --git a/candle-examples/examples/modernbert/main.rs b/candle-examples/examples/modernbert/main.rs
new file mode 100644
index 00000000..122aa995
--- /dev/null
+++ b/candle-examples/examples/modernbert/main.rs
@@ -0,0 +1,180 @@
+use std::path::PathBuf;
+
+use anyhow::{Error as E, Result};
+use candle::{Device, Tensor};
+use candle_nn::VarBuilder;
+use candle_transformers::models::modernbert;
+use clap::{Parser, ValueEnum};
+use hf_hub::{api::sync::Api, Repo, RepoType};
+use tokenizers::{PaddingParams, Tokenizer};
+
+#[derive(Debug, Clone, ValueEnum)]
+enum Model {
+ ModernBertBase,
+ ModernBertLarge,
+}
+
+#[derive(Parser, Debug)]
+#[command(author, version, about, long_about = None)]
+struct Args {
+ /// Run on CPU rather than on GPU.
+ #[arg(long)]
+ cpu: bool,
+
+ /// Enable tracing (generates a trace-timestamp.json file).
+ #[arg(long)]
+ tracing: bool,
+
+ #[arg(long)]
+ model_id: Option<String>,
+
+ #[arg(long, default_value = "main")]
+ revision: String,
+
+ #[arg(long, default_value = "modern-bert-base")]
+ model: Model,
+
+ /// Path to the tokenizer file.
+ #[arg(long)]
+ tokenizer_file: Option<String>,
+
+ /// Path to the weight files.
+ #[arg(long)]
+ weight_files: Option<String>,
+
+ /// Path to the config file.
+ #[arg(long)]
+ config_file: Option<String>,
+
+ /// When set, compute embeddings for this prompt.
+ #[arg(long)]
+ prompt: Option<String>,
+}
+
+fn main() -> Result<()> {
+ let args = Args::parse();
+ let api = Api::new()?;
+ let model_id = match &args.model_id {
+ Some(model_id) => model_id.to_string(),
+ None => match args.model {
+ Model::ModernBertBase => "answerdotai/ModernBERT-base".to_string(),
+ Model::ModernBertLarge => "answerdotai/ModernBERT-large".to_string(),
+ },
+ };
+ let repo = api.repo(Repo::with_revision(
+ model_id,
+ RepoType::Model,
+ args.revision,
+ ));
+
+ let tokenizer_filename = match args.tokenizer_file {
+ Some(file) => std::path::PathBuf::from(file),
+ None => repo.get("tokenizer.json")?,
+ };
+
+ let config_filename = match args.config_file {
+ Some(file) => std::path::PathBuf::from(file),
+ None => repo.get("config.json")?,
+ };
+
+ let weights_filename = match args.weight_files {
+ Some(files) => PathBuf::from(files),
+ None => match repo.get("model.safetensors") {
+ Ok(safetensors) => safetensors,
+ Err(_) => match repo.get("pytorch_model.bin") {
+ Ok(pytorch_model) => pytorch_model,
+ Err(e) => {
+ anyhow::bail!("Model weights not found. The weights should either be a `model.safetensors` or `pytorch_model.bin` file. Error: {e}")
+ }
+ },
+ },
+ };
+
+ let config = std::fs::read_to_string(config_filename)?;
+ let config: modernbert::Config = serde_json::from_str(&config)?;
+ let mut tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
+
+ let device = candle_examples::device(args.cpu)?;
+
+ let vb = if weights_filename.ends_with("model.safetensors") {
+ unsafe {
+ VarBuilder::from_mmaped_safetensors(&[weights_filename], candle::DType::F32, &device)
+ .unwrap()
+ }
+ } else {
+ println!("Loading weights from pytorch_model.bin");
+ VarBuilder::from_pth(&weights_filename, candle::DType::F32, &device).unwrap()
+ };
+ tokenizer
+ .with_padding(Some(PaddingParams {
+ strategy: tokenizers::PaddingStrategy::BatchLongest,
+ pad_id: config.pad_token_id,
+ ..Default::default()
+ }))
+ .with_truncation(None)
+ .map_err(E::msg)?;
+
+ let prompt = match &args.prompt {
+ Some(p) => vec![p.as_str()],
+ None => vec![
+ "Hello I'm a [MASK] model.",
+ "I'm a [MASK] boy.",
+ "I'm [MASK] in berlin.",
+ "The capital of France is [MASK].",
+ ],
+ };
+ let model = modernbert::ModernBertForMaskedLM::load(vb, &config)?;
+
+ let input_ids = tokenize_batch(&tokenizer, prompt.clone(), &device)?;
+ let attention_mask = get_attention_mask(&tokenizer, prompt.clone(), &device)?;
+
+ let output = model
+ .forward(&input_ids, &attention_mask)?
+ .to_dtype(candle::DType::F32)?;
+
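+ // Greedy fill-mask decoding: take the most probable token id at every position,
+ // then turn each sequence of ids back into text.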
+ let max_outs = output.argmax(2)?;
+
+ let max_out = max_outs.to_vec2::<u32>()?;
+ let max_out_refs: Vec<&[u32]> = max_out.iter().map(|v| v.as_slice()).collect();
+ let decoded = tokenizer.decode_batch(&max_out_refs, true).unwrap();
+ for (i, sentence) in decoded.iter().enumerate() {
+ println!("Sentence: {} : {}", i + 1, sentence);
+ }
+
+ Ok(())
+}
+
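+// Encode a batch of prompts into a (batch, seq_len) tensor of token ids; padding to the
+// longest sequence in the batch is configured on the tokenizer in `main`.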
+pub fn tokenize_batch(
+ tokenizer: &Tokenizer,
+ input: Vec<&str>,
+ device: &Device,
+) -> anyhow::Result<Tensor> {
+ let tokens = tokenizer.encode_batch(input, true).map_err(E::msg)?;
+
+ let token_ids = tokens
+ .iter()
+ .map(|tokens| {
+ let tokens = tokens.get_ids().to_vec();
+ Tensor::new(tokens.as_slice(), device)
+ })
+ .collect::<candle::Result<Vec<_>>>()?;
+
+ Ok(Tensor::stack(&token_ids, 0)?)
+}
+
+pub fn get_attention_mask(
+ tokenizer: &Tokenizer,
+ input: Vec<&str>,
+ device: &Device,
+) -> anyhow::Result<Tensor> {
+ let tokens = tokenizer.encode_batch(input, true).map_err(E::msg)?;
+
+ let attention_mask = tokens
+ .iter()
+ .map(|tokens| {
+ let tokens = tokens.get_attention_mask().to_vec();
+ Tensor::new(tokens.as_slice(), device)
+ })
+ .collect::<candle::Result<Vec<_>>>()?;
+ Ok(Tensor::stack(&attention_mask, 0)?)
+}
diff --git a/candle-nn/src/layer_norm.rs b/candle-nn/src/layer_norm.rs
index b7dd61cb..468fe24d 100644
--- a/candle-nn/src/layer_norm.rs
+++ b/candle-nn/src/layer_norm.rs
@@ -155,6 +155,15 @@ pub fn layer_norm<C: Into<LayerNormConfig>>(
})
}
+pub fn layer_norm_no_bias(size: usize, eps: f64, vb: crate::VarBuilder) -> Result<LayerNorm> {
+ let config = LayerNormConfig {
+ eps,
+ remove_mean: true,
+ affine: false,
+ };
+ layer_norm(size, config, vb)
+}
+
/// RmsNorm is a specialized version of the LayerNorm module.
#[derive(Clone, Debug)]
pub struct RmsNorm(LayerNorm);
diff --git a/candle-nn/src/lib.rs b/candle-nn/src/lib.rs
index eb3cde4a..2113566d 100644
--- a/candle-nn/src/lib.rs
+++ b/candle-nn/src/lib.rs
@@ -46,7 +46,9 @@ pub use embedding::{embedding, Embedding};
pub use func::{func, func_t, Func, FuncT};
pub use group_norm::{group_norm, GroupNorm};
pub use init::Init;
-pub use layer_norm::{layer_norm, rms_norm, LayerNorm, LayerNormConfig, RmsNorm};
+pub use layer_norm::{
+ layer_norm, layer_norm_no_bias, rms_norm, LayerNorm, LayerNormConfig, RmsNorm,
+};
pub use linear::{linear, linear_b, linear_no_bias, Linear};
pub use ops::Dropout;
pub use optim::{AdamW, Optimizer, ParamsAdamW, SGD};
diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs
index 5f566991..473a276f 100644
--- a/candle-transformers/src/models/mod.rs
+++ b/candle-transformers/src/models/mod.rs
@@ -60,6 +60,7 @@ pub mod mmdit;
pub mod mobileclip;
pub mod mobilenetv4;
pub mod mobileone;
+pub mod modernbert;
pub mod moondream;
pub mod mpt;
pub mod nvembed_v2;
diff --git a/candle-transformers/src/models/modernbert.rs b/candle-transformers/src/models/modernbert.rs
new file mode 100644
index 00000000..b0ba9b46
--- /dev/null
+++ b/candle-transformers/src/models/modernbert.rs
@@ -0,0 +1,407 @@
+//! ModernBERT
+//!
+//! ModernBERT is a modernized bidirectional encoder-only Transformer model.
+//! - [Arxiv](https://arxiv.org/abs/2412.13663) "Smarter, Better, Faster, Longer: A Modern Bidirectional Encoder for Fast, Memory Efficient, and Long Context Finetuning and Inference"
+//! - Upstream [Github repo](https://github.com/AnswerDotAI/ModernBERT).
+//! - See modernbert in [candle-examples](https://github.com/huggingface/candle/tree/main/candle-examples/) for runnable code
+//!
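+//! A minimal usage sketch, assuming `vb`, `config`, `input_ids` and `attention_mask`
+//! have already been prepared (e.g. as in the example linked above):
+//!
+//! ```rust,ignore
+//! use candle_transformers::models::modernbert::ModernBertForMaskedLM;
+//!
+//! let model = ModernBertForMaskedLM::load(vb, &config)?;
+//! // Logits over the vocabulary for every position: (batch, seq_len, vocab_size).
+//! let logits = model.forward(&input_ids, &attention_mask)?;
+//! ```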
+
+use candle::{DType, Device, Result, Tensor, D};
+use candle_nn::{
+ embedding, layer_norm_no_bias, linear_no_bias, ops::softmax, Embedding, LayerNorm, Linear,
+ Module, VarBuilder,
+};
+use serde::Deserialize;
+
+use core::f32;
+use std::sync::Arc;
+
+#[derive(Debug, Clone, PartialEq, Deserialize)]
+pub struct Config {
+ pub vocab_size: usize,
+ pub hidden_size: usize,
+ pub num_hidden_layers: usize,
+ pub num_attention_heads: usize,
+ pub intermediate_size: usize,
+ pub max_position_embeddings: usize,
+ pub layer_norm_eps: f64,
+ pub pad_token_id: u32,
+ pub global_attn_every_n_layers: usize,
+ pub global_rope_theta: f64,
+ pub local_attention: usize,
+ pub local_rope_theta: f64,
+}
+
+#[derive(Debug, Clone)]
+struct RotaryEmbedding {
+ sin: Tensor,
+ cos: Tensor,
+}
+
+impl RotaryEmbedding {
+ fn new(dtype: DType, config: &Config, rope_theta: f64, dev: &Device) -> Result<Self> {
+ let dim = config.hidden_size / config.num_attention_heads;
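+ // Standard RoPE inverse frequencies 1 / rope_theta^(i / dim) for even i, precomputed
+ // for every position up to max_position_embeddings.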
+ let inv_freq: Vec<_> = (0..dim)
+ .step_by(2)
+ .map(|i| 1f32 / rope_theta.powf(i as f64 / dim as f64) as f32)
+ .collect();
+ let inv_freq_len = inv_freq.len();
+ let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
+ let max_seq_len = config.max_position_embeddings;
+ let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
+ .to_dtype(dtype)?
+ .reshape((max_seq_len, 1))?;
+ let freqs = t.matmul(&inv_freq)?;
+ Ok(Self {
+ sin: freqs.sin()?,
+ cos: freqs.cos()?,
+ })
+ }
+
+ fn apply_rotary_emb_qkv(&self, q: &Tensor, k: &Tensor) -> Result<(Tensor, Tensor)> {
+ let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &self.cos, &self.sin)?;
+ let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &self.cos, &self.sin)?;
+ Ok((q_embed, k_embed))
+ }
+}
+
+#[derive(Clone)]
+struct ModernBertAttention {
+ qkv: Linear,
+ proj: Linear,
+ num_attention_heads: usize,
+ attention_head_size: usize,
+ rotary_emb: Arc<RotaryEmbedding>,
+}
+
+impl ModernBertAttention {
+ fn load(vb: VarBuilder, config: &Config, rotary_emb: Arc<RotaryEmbedding>) -> Result<Self> {
+ let num_attention_heads = config.num_attention_heads;
+ let attention_head_size = config.hidden_size / config.num_attention_heads;
+
+ let qkv = linear_no_bias(config.hidden_size, config.hidden_size * 3, vb.pp("Wqkv"))?;
+ let proj = linear_no_bias(config.hidden_size, config.hidden_size, vb.pp("Wo"))?;
+
+ Ok(Self {
+ qkv,
+ proj,
+ num_attention_heads,
+ attention_head_size,
+ rotary_emb,
+ })
+ }
+
+ fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
+ let xs = hidden_states.clone();
+ let (b, seq_len, d) = xs.dims3()?;
+ let qkv = xs
+ .apply(&self.qkv)?
+ .reshape((
+ b,
+ seq_len,
+ 3,
+ self.num_attention_heads,
+ self.attention_head_size,
+ ))?
+ .permute((2, 0, 3, 1, 4))?;
+
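+ // After the permute, qkv has shape (3, b, num_heads, seq_len, head_size);
+ // indexing the first dimension splits out q, k and v.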
+ let q = qkv.get(0)?;
+ let k = qkv.get(1)?;
+ let v = qkv.get(2)?;
+
+ let (q, k) = self.rotary_emb.apply_rotary_emb_qkv(&q, &k)?;
+
+ let scale = (self.attention_head_size as f64).powf(-0.5);
+ let q = (q * scale)?;
+
+ let att = q.matmul(&k.transpose(D::Minus2, D::Minus1)?)?;
+
+ let att = att.broadcast_add(attention_mask)?;
+ let att = softmax(&att, D::Minus1)?;
+
+ let xs = att.matmul(&v)?;
+
+ let xs = xs.transpose(1, 2)?.reshape((b, seq_len, d))?;
+ let xs = xs.apply(&self.proj)?;
+ let xs = xs.reshape((b, seq_len, d))?;
+
+ Ok(xs)
+ }
+}
+
+#[derive(Clone)]
+pub struct ModernBertMLP {
+ wi: Linear,
+ wo: Linear,
+}
+
+impl ModernBertMLP {
+ fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
+ let wi = linear_no_bias(
+ config.hidden_size,
+ config.intermediate_size * 2,
+ vb.pp("Wi"),
+ )?;
+ let wo = linear_no_bias(config.intermediate_size, config.hidden_size, vb.pp("Wo"))?;
+ Ok(Self { wi, wo })
+ }
+}
+
+impl Module for ModernBertMLP {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let xs = xs.apply(&self.wi)?;
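+ // Wi maps hidden_size to 2 * intermediate_size; the GELU of the first half gates
+ // the second half (GeGLU) before the Wo projection.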
+ let xs = xs.chunk(2, D::Minus1)?;
+ let xs = (&xs[0].gelu_erf()? * &xs[1])?.apply(&self.wo)?; // GeGLU
+ Ok(xs)
+ }
+}
+
+#[derive(Clone)]
+pub struct ModernBertLayer {
+ attn: ModernBertAttention,
+ mlp: ModernBertMLP,
+ attn_norm: Option<LayerNorm>,
+ mlp_norm: LayerNorm,
+ uses_local_attention: bool,
+}
+
+impl ModernBertLayer {
+ fn load(
+ vb: VarBuilder,
+ config: &Config,
+ rotary_emb: Arc<RotaryEmbedding>,
+ uses_local_attention: bool,
+ ) -> Result<Self> {
+ let attn = ModernBertAttention::load(vb.pp("attn"), config, rotary_emb)?;
+ let mlp = ModernBertMLP::load(vb.pp("mlp"), config)?;
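+ // The first layer stores no attn_norm weights (the embedding norm already normalizes
+ // its input), so a missing tensor falls back to the identity.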
+ let attn_norm = layer_norm_no_bias(
+ config.hidden_size,
+ config.layer_norm_eps,
+ vb.pp("attn_norm"),
+ )
+ .ok();
+ let mlp_norm =
+ layer_norm_no_bias(config.hidden_size, config.layer_norm_eps, vb.pp("mlp_norm"))?;
+ Ok(Self {
+ attn,
+ mlp,
+ attn_norm,
+ mlp_norm,
+ uses_local_attention,
+ })
+ }
+
+ fn forward(
+ &self,
+ xs: &Tensor,
+ global_attention_mask: &Tensor,
+ local_attention_mask: &Tensor,
+ ) -> Result<Tensor> {
+ let residual = xs.clone();
+ let mut xs = xs.clone();
+ if let Some(norm) = &self.attn_norm {
+ xs = xs.apply(norm)?;
+ }
+
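+ // Sliding-window layers add the local mask on top of the padding mask;
+ // global layers attend over the full (padded) sequence.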
+ let attention_mask = if self.uses_local_attention {
+ &global_attention_mask.broadcast_add(local_attention_mask)?
+ } else {
+ global_attention_mask
+ };
+ let xs = self.attn.forward(&xs, attention_mask)?;
+ let xs = (xs + residual)?;
+ let mlp_out = xs.apply(&self.mlp_norm)?.apply(&self.mlp)?;
+ let xs = (xs + mlp_out)?;
+ Ok(xs)
+ }
+}
+
+#[derive(Clone)]
+pub struct ModernBertHead {
+ dense: Linear,
+ norm: LayerNorm,
+}
+
+impl ModernBertHead {
+ fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
+ let dense = linear_no_bias(config.hidden_size, config.hidden_size, vb.pp("dense"))?;
+ let norm = layer_norm_no_bias(config.hidden_size, config.layer_norm_eps, vb.pp("norm"))?;
+ Ok(Self { dense, norm })
+ }
+}
+
+impl Module for ModernBertHead {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let xs = xs.apply(&self.dense)?.gelu_erf()?.apply(&self.norm)?;
+ Ok(xs)
+ }
+}
+
+#[derive(Clone)]
+pub struct ModernBertDecoder {
+ decoder: Linear,
+}
+
+impl ModernBertDecoder {
+ fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
+ // The decoder weights are tied with the embeddings layer weights
+ let decoder_weights = vb.get(
+ (config.vocab_size, config.hidden_size),
+ "model.embeddings.tok_embeddings.weight",
+ )?;
+ let decoder_bias = vb.get(config.vocab_size, "decoder.bias")?;
+ let decoder = Linear::new(decoder_weights, Some(decoder_bias));
+ Ok(Self { decoder })
+ }
+}
+
+impl Module for ModernBertDecoder {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let xs = xs.apply(&self.decoder)?;
+ Ok(xs)
+ }
+}
+
+// Global attention mask calculated from padded token inputs
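+// Padding positions are mapped to a large negative value so they vanish after softmax;
+// the result broadcasts over heads as (bsz, 1, tgt_len, src_len).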
+fn prepare_4d_attention_mask(
+ mask: &Tensor,
+ dtype: DType,
+ tgt_len: Option<usize>,
+) -> Result<Tensor> {
+ let bsz = mask.dim(0)?;
+ let src_len = mask.dim(1)?;
+ let tgt_len = tgt_len.unwrap_or(src_len);
+
+ let expanded_mask = mask
+ .unsqueeze(1)?
+ .unsqueeze(2)?
+ .expand((bsz, 1, tgt_len, src_len))?
+ .to_dtype(dtype)?;
+
+ let inverted_mask = (1.0 - expanded_mask)?;
+
+ (inverted_mask * f32::MIN as f64)?.to_dtype(dtype)
+}
+
+// Attention mask caused by the sliding window
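+// Positions more than `max_distance` tokens apart get -inf, limiting each token to a
+// window of 2 * max_distance + 1 tokens.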
+fn get_local_attention_mask(
+ seq_len: usize,
+ max_distance: usize,
+ device: &Device,
+) -> Result<Tensor> {
+ let mask: Vec<_> = (0..seq_len)
+ .flat_map(|i| {
+ (0..seq_len).map(move |j| {
+ if (j as i32 - i as i32).abs() > max_distance as i32 {
+ f32::NEG_INFINITY
+ } else {
+ 0.
+ }
+ })
+ })
+ .collect();
+ Tensor::from_slice(&mask, (seq_len, seq_len), device)
+}
+
+// ModernBERT backbone
+#[derive(Clone)]
+pub struct ModernBert {
+ word_embeddings: Embedding,
+ norm: LayerNorm,
+ layers: Vec<ModernBertLayer>,
+ final_norm: LayerNorm,
+ head: ModernBertHead,
+ local_attention_size: usize,
+}
+
+impl ModernBert {
+ fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
+ let word_embeddings = embedding(
+ config.vocab_size,
+ config.hidden_size,
+ vb.pp("model.embeddings.tok_embeddings"),
+ )?;
+ let norm = layer_norm_no_bias(
+ config.hidden_size,
+ config.layer_norm_eps,
+ vb.pp("model.embeddings.norm"),
+ )?;
+ let global_rotary_emb = Arc::new(RotaryEmbedding::new(
+ vb.dtype(),
+ config,
+ config.global_rope_theta,
+ vb.device(),
+ )?);
+ let local_rotary_emb = Arc::new(RotaryEmbedding::new(
+ vb.dtype(),
+ config,
+ config.local_rope_theta,
+ vb.device(),
+ )?);
+
+ let mut layers = Vec::with_capacity(config.num_hidden_layers);
+ for layer_id in 0..config.num_hidden_layers {
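+ // Every `global_attn_every_n_layers`-th layer (starting with layer 0) attends globally;
+ // the remaining layers use sliding-window attention with their own RoPE theta.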
+ let layer_uses_local_attention = layer_id % config.global_attn_every_n_layers != 0;
+ layers.push(ModernBertLayer::load(
+ vb.pp(format!("model.layers.{layer_id}")),
+ config,
+ if layer_uses_local_attention {
+ local_rotary_emb.clone()
+ } else {
+ global_rotary_emb.clone()
+ },
+ layer_uses_local_attention,
+ )?);
+ }
+
+ let final_norm = layer_norm_no_bias(
+ config.hidden_size,
+ config.layer_norm_eps,
+ vb.pp("model.final_norm"),
+ )?;
+ let head = ModernBertHead::load(vb.pp("head"), config)?;
+
+ Ok(Self {
+ word_embeddings,
+ norm,
+ layers,
+ final_norm,
+ head,
+ local_attention_size: config.local_attention,
+ })
+ }
+
+ fn forward(&self, xs: &Tensor, mask: &Tensor) -> Result<Tensor> {
+ let seq_len = xs.shape().dims()[1];
+ let global_attention_mask =
+ prepare_4d_attention_mask(mask, DType::F32, None)?.to_device(xs.device())?;
+ let local_attention_mask =
+ get_local_attention_mask(seq_len, self.local_attention_size / 2, xs.device())?;
+ let mut xs = xs.apply(&self.word_embeddings)?.apply(&self.norm)?;
+ for layer in self.layers.iter() {
+ xs = layer.forward(&xs, &global_attention_mask, &local_attention_mask)?;
+ }
+ let xs = xs.apply(&self.final_norm)?.apply(&self.head)?;
+ Ok(xs)
+ }
+}
+
+// ModernBERT for the fill-mask task
+#[derive(Clone)]
+pub struct ModernBertForMaskedLM {
+ model: ModernBert,
+ decoder: ModernBertDecoder,
+}
+
+impl ModernBertForMaskedLM {
+ pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
+ let model = ModernBert::load(vb.clone(), config)?;
+ let decoder = ModernBertDecoder::load(vb.clone(), config)?;
+ Ok(Self { model, decoder })
+ }
+
+ pub fn forward(&self, xs: &Tensor, mask: &Tensor) -> Result<Tensor> {
+ let xs = self.model.forward(xs, mask)?.apply(&self.decoder)?;
+ Ok(xs)
+ }
+}