//! ModernBERT //! //! ModernBERT is a modernized bidirectional encoder-only Transformer model. //! - [Arxiv](https://arxiv.org/abs/2412.13663) "Smarter, Better, Faster, Longer: A Modern Bidirectional Encoder for Fast, Memory Efficient, and Long Context Finetuning and Inference" //! - Upstream [Github repo](https://github.com/AnswerDotAI/ModernBERT). //! - See modernbert in [candle-examples](https://github.com/huggingface/candle/tree/main/candle-examples/) for runnable code //! use candle::{DType, Device, Result, Tensor, D}; use candle_nn::{ embedding, layer_norm_no_bias, linear_no_bias, ops::softmax, Embedding, LayerNorm, Linear, Module, VarBuilder, }; use serde::Deserialize; use core::f32; use std::sync::Arc; #[derive(Debug, Clone, PartialEq, Deserialize)] pub struct Config { pub vocab_size: usize, pub hidden_size: usize, pub num_hidden_layers: usize, pub num_attention_heads: usize, pub intermediate_size: usize, pub max_position_embeddings: usize, pub layer_norm_eps: f64, pub pad_token_id: u32, pub global_attn_every_n_layers: usize, pub global_rope_theta: f64, pub local_attention: usize, pub local_rope_theta: f64, } #[derive(Debug, Clone)] struct RotaryEmbedding { sin: Tensor, cos: Tensor, } impl RotaryEmbedding { fn new(dtype: DType, config: &Config, rope_theta: f64, dev: &Device) -> Result { let dim = config.hidden_size / config.num_attention_heads; let inv_freq: Vec<_> = (0..dim) .step_by(2) .map(|i| 1f32 / rope_theta.powf(i as f64 / dim as f64) as f32) .collect(); let inv_freq_len = inv_freq.len(); let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?; let max_seq_len = config.max_position_embeddings; let t = Tensor::arange(0u32, max_seq_len as u32, dev)? .to_dtype(dtype)? .reshape((max_seq_len, 1))?; let freqs = t.matmul(&inv_freq)?; Ok(Self { sin: freqs.sin()?, cos: freqs.cos()?, }) } fn apply_rotary_emb_qkv(&self, q: &Tensor, k: &Tensor) -> Result<(Tensor, Tensor)> { let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &self.cos, &self.sin)?; let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &self.cos, &self.sin)?; Ok((q_embed, k_embed)) } } #[derive(Clone)] struct ModernBertAttention { qkv: Linear, proj: Linear, num_attention_heads: usize, attention_head_size: usize, rotary_emb: Arc, } impl ModernBertAttention { fn load(vb: VarBuilder, config: &Config, rotary_emb: Arc) -> Result { let num_attention_heads = config.num_attention_heads; let attention_head_size = config.hidden_size / config.num_attention_heads; let qkv = linear_no_bias(config.hidden_size, config.hidden_size * 3, vb.pp("Wqkv"))?; let proj = linear_no_bias(config.hidden_size, config.hidden_size, vb.pp("Wo"))?; Ok(Self { qkv, proj, num_attention_heads, attention_head_size, rotary_emb, }) } fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result { let xs = hidden_states.clone(); let (b, seq_len, d) = xs.dims3()?; let qkv = xs .apply(&self.qkv)? .reshape(( b, seq_len, 3, self.num_attention_heads, self.attention_head_size, ))? .permute((2, 0, 3, 1, 4))?; let q = qkv.get(0)?; let k = qkv.get(1)?; let v = qkv.get(2)?; let (q, k) = self.rotary_emb.apply_rotary_emb_qkv(&q, &k)?; let scale = (self.attention_head_size as f64).powf(-0.5); let q = (q * scale)?; let att = q.matmul(&k.transpose(D::Minus2, D::Minus1)?)?; let att = att.broadcast_add(attention_mask)?; let att = softmax(&att, D::Minus1)?; let xs = att.matmul(&v)?; let xs = xs.transpose(1, 2)?.reshape((b, seq_len, d))?; let xs = xs.apply(&self.proj)?; let xs = xs.reshape((b, seq_len, d))?; Ok(xs) } } #[derive(Clone)] pub struct ModernBertMLP { wi: Linear, wo: Linear, } impl ModernBertMLP { fn load(vb: VarBuilder, config: &Config) -> Result { let wi = linear_no_bias( config.hidden_size, config.intermediate_size * 2, vb.pp("Wi"), )?; let wo = linear_no_bias(config.intermediate_size, config.hidden_size, vb.pp("Wo"))?; Ok(Self { wi, wo }) } } impl Module for ModernBertMLP { fn forward(&self, xs: &Tensor) -> Result { let xs = xs.apply(&self.wi)?; let xs = xs.chunk(2, D::Minus1)?; let xs = (&xs[0].gelu_erf()? * &xs[1])?.apply(&self.wo)?; // GeGLU Ok(xs) } } #[derive(Clone)] pub struct ModernBertLayer { attn: ModernBertAttention, mlp: ModernBertMLP, attn_norm: Option, mlp_norm: LayerNorm, uses_local_attention: bool, } impl ModernBertLayer { fn load( vb: VarBuilder, config: &Config, rotary_emb: Arc, uses_local_attention: bool, ) -> Result { let attn = ModernBertAttention::load(vb.pp("attn"), config, rotary_emb)?; let mlp = ModernBertMLP::load(vb.pp("mlp"), config)?; let attn_norm = layer_norm_no_bias( config.hidden_size, config.layer_norm_eps, vb.pp("attn_norm"), ) .ok(); let mlp_norm = layer_norm_no_bias(config.hidden_size, config.layer_norm_eps, vb.pp("mlp_norm"))?; Ok(Self { attn, mlp, attn_norm, mlp_norm, uses_local_attention, }) } fn forward( &self, xs: &Tensor, global_attention_mask: &Tensor, local_attention_mask: &Tensor, ) -> Result { let residual = xs.clone(); let mut xs = xs.clone(); if let Some(norm) = &self.attn_norm { xs = xs.apply(norm)?; } let attention_mask = if self.uses_local_attention { &global_attention_mask.broadcast_add(local_attention_mask)? } else { global_attention_mask }; let xs = self.attn.forward(&xs, attention_mask)?; let xs = (xs + residual)?; let mlp_out = xs.apply(&self.mlp_norm)?.apply(&self.mlp)?; let xs = (xs + mlp_out)?; Ok(xs) } } #[derive(Clone)] pub struct ModernBertHead { dense: Linear, norm: LayerNorm, } impl ModernBertHead { fn load(vb: VarBuilder, config: &Config) -> Result { let dense = linear_no_bias(config.hidden_size, config.hidden_size, vb.pp("dense"))?; let norm = layer_norm_no_bias(config.hidden_size, config.layer_norm_eps, vb.pp("norm"))?; Ok(Self { dense, norm }) } } impl Module for ModernBertHead { fn forward(&self, xs: &Tensor) -> Result { let xs = xs.apply(&self.dense)?.gelu_erf()?.apply(&self.norm)?; Ok(xs) } } #[derive(Clone)] pub struct ModernBertDecoder { decoder: Linear, } impl ModernBertDecoder { fn load(vb: VarBuilder, config: &Config) -> Result { // The decoder weights are tied with the embeddings layer weights let decoder_weights = vb.get( (config.vocab_size, config.hidden_size), "model.embeddings.tok_embeddings.weight", )?; let decoder_bias = vb.get(config.vocab_size, "decoder.bias")?; let decoder = Linear::new(decoder_weights, Some(decoder_bias)); Ok(Self { decoder }) } } impl Module for ModernBertDecoder { fn forward(&self, xs: &Tensor) -> Result { let xs = xs.apply(&self.decoder)?; Ok(xs) } } // Global attention mask calculated from padded token inputs fn prepare_4d_attention_mask( mask: &Tensor, dtype: DType, tgt_len: Option, ) -> Result { let bsz = mask.dim(0)?; let src_len = mask.dim(1)?; let tgt_len = tgt_len.unwrap_or(src_len); let expanded_mask = mask .unsqueeze(1)? .unsqueeze(2)? .expand((bsz, 1, tgt_len, src_len))? .to_dtype(dtype)?; let inverted_mask = (1.0 - expanded_mask)?; (inverted_mask * f32::MIN as f64)?.to_dtype(dtype) } // Attention mask caused by the sliding window fn get_local_attention_mask( seq_len: usize, max_distance: usize, device: &Device, ) -> Result { let mask: Vec<_> = (0..seq_len) .flat_map(|i| { (0..seq_len).map(move |j| { if (j as i32 - i as i32).abs() > max_distance as i32 { f32::NEG_INFINITY } else { 0. } }) }) .collect(); Tensor::from_slice(&mask, (seq_len, seq_len), device) } // ModernBERT backbone #[derive(Clone)] pub struct ModernBert { word_embeddings: Embedding, norm: LayerNorm, layers: Vec, final_norm: LayerNorm, head: ModernBertHead, local_attention_size: usize, } impl ModernBert { fn load(vb: VarBuilder, config: &Config) -> Result { let word_embeddings = embedding( config.vocab_size, config.hidden_size, vb.pp("model.embeddings.tok_embeddings"), )?; let norm = layer_norm_no_bias( config.hidden_size, config.layer_norm_eps, vb.pp("model.embeddings.norm"), )?; let global_rotary_emb = Arc::new(RotaryEmbedding::new( vb.dtype(), config, config.global_rope_theta, vb.device(), )?); let local_rotary_emb = Arc::new(RotaryEmbedding::new( vb.dtype(), config, config.local_rope_theta, vb.device(), )?); let mut layers = Vec::with_capacity(config.num_hidden_layers); for layer_id in 0..config.num_hidden_layers { let layer_uses_local_attention = layer_id % config.global_attn_every_n_layers != 0; layers.push(ModernBertLayer::load( vb.pp(format!("model.layers.{layer_id}")), config, if layer_uses_local_attention { local_rotary_emb.clone() } else { global_rotary_emb.clone() }, layer_uses_local_attention, )?); } let final_norm = layer_norm_no_bias( config.hidden_size, config.layer_norm_eps, vb.pp("model.final_norm"), )?; let head = ModernBertHead::load(vb.pp("head"), config)?; Ok(Self { word_embeddings, norm, layers, final_norm, head, local_attention_size: config.local_attention, }) } fn forward(&self, xs: &Tensor, mask: &Tensor) -> Result { let seq_len = xs.shape().dims()[1]; let global_attention_mask = prepare_4d_attention_mask(mask, DType::F32, None)?.to_device(xs.device())?; let local_attention_mask = get_local_attention_mask(seq_len, self.local_attention_size / 2, xs.device())?; let mut xs = xs.apply(&self.word_embeddings)?.apply(&self.norm)?; for layer in self.layers.iter() { xs = layer.forward(&xs, &global_attention_mask, &local_attention_mask)?; } let xs = xs.apply(&self.final_norm)?.apply(&self.head)?; Ok(xs) } } // ModernBERT for the fill-mask task #[derive(Clone)] pub struct ModernBertForMaskedLM { model: ModernBert, decoder: ModernBertDecoder, } impl ModernBertForMaskedLM { pub fn load(vb: VarBuilder, config: &Config) -> Result { let model = ModernBert::load(vb.clone(), config)?; let decoder = ModernBertDecoder::load(vb.clone(), config)?; Ok(Self { model, decoder }) } pub fn forward(&self, xs: &Tensor, mask: &Tensor) -> Result { let xs = self.model.forward(xs, mask)?.apply(&self.decoder)?; Ok(xs) } }