summaryrefslogtreecommitdiff
path: root/candle-examples
diff options
context:
space:
mode:
authorJani Monoses <jani.monoses@gmail.com>2025-01-13 09:39:27 +0200
committerGitHub <noreply@github.com>2025-01-13 08:39:27 +0100
commit461e8c1685e003bdddfd1e7d1aa5092786ca9df5 (patch)
treeab847343b9305176db0ce630246330a7a04e84bd /candle-examples
parent2344c4e4b89dcb57c021459140c3914faa4df603 (diff)
downloadcandle-461e8c1685e003bdddfd1e7d1aa5092786ca9df5.tar.gz
candle-461e8c1685e003bdddfd1e7d1aa5092786ca9df5.tar.bz2
candle-461e8c1685e003bdddfd1e7d1aa5092786ca9df5.zip
ModernBERT model (#2713)
* layer_norm_no_bias * Modernbert model. * Format + cleanup error. --------- Co-authored-by: laurent <laurent.mazare@gmail.com>
Diffstat (limited to 'candle-examples')
-rw-r--r--candle-examples/examples/modernbert/README.md12
-rw-r--r--candle-examples/examples/modernbert/main.rs180
2 files changed, 192 insertions, 0 deletions
diff --git a/candle-examples/examples/modernbert/README.md b/candle-examples/examples/modernbert/README.md
new file mode 100644
index 00000000..4eba2d7d
--- /dev/null
+++ b/candle-examples/examples/modernbert/README.md
@@ -0,0 +1,12 @@
+# candle-modernbert
+
+ModernBERT is a bidirectional encoder-only language model. In this example it is used for the fill-mask task.
+
+## Usage
+
+```bash
+cargo run --example modernbert --release -- --model modern-bert-large --prompt 'The capital of France is [MASK].'
+```
+```markdown
+Sentence: 1 : The capital of France is Paris.
+```
diff --git a/candle-examples/examples/modernbert/main.rs b/candle-examples/examples/modernbert/main.rs
new file mode 100644
index 00000000..122aa995
--- /dev/null
+++ b/candle-examples/examples/modernbert/main.rs
@@ -0,0 +1,180 @@
+use std::path::PathBuf;
+
+use anyhow::{Error as E, Result};
+use candle::{Device, Tensor};
+use candle_nn::VarBuilder;
+use candle_transformers::models::modernbert;
+use clap::{Parser, ValueEnum};
+use hf_hub::{api::sync::Api, Repo, RepoType};
+use tokenizers::{PaddingParams, Tokenizer};
+
+#[derive(Debug, Clone, ValueEnum)]
+enum Model {
+ ModernBertBase,
+ ModernBertLarge,
+}
+
+#[derive(Parser, Debug)]
+#[command(author, version, about, long_about = None)]
+struct Args {
+ /// Run on CPU rather than on GPU.
+ #[arg(long)]
+ cpu: bool,
+
+ /// Enable tracing (generates a trace-timestamp.json file).
+ #[arg(long)]
+ tracing: bool,
+
+ #[arg(long)]
+ model_id: Option<String>,
+
+ #[arg(long, default_value = "main")]
+ revision: String,
+
+ #[arg(long, default_value = "modern-bert-base")]
+ model: Model,
+
+ // Path to the tokenizer file.
+ #[arg(long)]
+ tokenizer_file: Option<String>,
+
+ // Path to the weight files.
+ #[arg(long)]
+ weight_files: Option<String>,
+
+ // Path to the config file.
+ #[arg(long)]
+ config_file: Option<String>,
+
+ /// When set, compute embeddings for this prompt.
+ #[arg(long)]
+ prompt: Option<String>,
+}
+
+fn main() -> Result<()> {
+ let args = Args::parse();
+ let api = Api::new()?;
+ let model_id = match &args.model_id {
+ Some(model_id) => model_id.to_string(),
+ None => match args.model {
+ Model::ModernBertBase => "answerdotai/ModernBERT-base".to_string(),
+ Model::ModernBertLarge => "answerdotai/ModernBERT-large".to_string(),
+ },
+ };
+ let repo = api.repo(Repo::with_revision(
+ model_id,
+ RepoType::Model,
+ args.revision,
+ ));
+
+ let tokenizer_filename = match args.tokenizer_file {
+ Some(file) => std::path::PathBuf::from(file),
+ None => repo.get("tokenizer.json")?,
+ };
+
+ let config_filename = match args.config_file {
+ Some(file) => std::path::PathBuf::from(file),
+ None => repo.get("config.json")?,
+ };
+
+ let weights_filename = match args.weight_files {
+ Some(files) => PathBuf::from(files),
+ None => match repo.get("model.safetensors") {
+ Ok(safetensors) => safetensors,
+ Err(_) => match repo.get("pytorch_model.bin") {
+ Ok(pytorch_model) => pytorch_model,
+ Err(e) => {
+ anyhow::bail!("Model weights not found. The weights should either be a `model.safetensors` or `pytorch_model.bin` file. Error: {e}")
+ }
+ },
+ },
+ };
+
+ let config = std::fs::read_to_string(config_filename)?;
+ let config: modernbert::Config = serde_json::from_str(&config)?;
+ let mut tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
+
+ let device = candle_examples::device(args.cpu)?;
+
+ let vb = if weights_filename.ends_with("model.safetensors") {
+ unsafe {
+ VarBuilder::from_mmaped_safetensors(&[weights_filename], candle::DType::F32, &device)
+ .unwrap()
+ }
+ } else {
+ println!("Loading weights from pytorch_model.bin");
+ VarBuilder::from_pth(&weights_filename, candle::DType::F32, &device).unwrap()
+ };
+ tokenizer
+ .with_padding(Some(PaddingParams {
+ strategy: tokenizers::PaddingStrategy::BatchLongest,
+ pad_id: config.pad_token_id,
+ ..Default::default()
+ }))
+ .with_truncation(None)
+ .map_err(E::msg)?;
+
+ let prompt = match &args.prompt {
+ Some(p) => vec![p.as_str()],
+ None => vec![
+ "Hello I'm a [MASK] model.",
+ "I'm a [MASK] boy.",
+ "I'm [MASK] in berlin.",
+ "The capital of France is [MASK].",
+ ],
+ };
+ let model = modernbert::ModernBertForMaskedLM::load(vb, &config)?;
+
+ let input_ids = tokenize_batch(&tokenizer, prompt.clone(), &device)?;
+ let attention_mask = get_attention_mask(&tokenizer, prompt.clone(), &device)?;
+
+ let output = model
+ .forward(&input_ids, &attention_mask)?
+ .to_dtype(candle::DType::F32)?;
+
+ let max_outs = output.argmax(2)?;
+
+ let max_out = max_outs.to_vec2::<u32>()?;
+ let max_out_refs: Vec<&[u32]> = max_out.iter().map(|v| v.as_slice()).collect();
+ let decoded = tokenizer.decode_batch(&max_out_refs, true).unwrap();
+ for (i, sentence) in decoded.iter().enumerate() {
+ println!("Sentence: {} : {}", i + 1, sentence);
+ }
+
+ Ok(())
+}
+
+pub fn tokenize_batch(
+ tokenizer: &Tokenizer,
+ input: Vec<&str>,
+ device: &Device,
+) -> anyhow::Result<Tensor> {
+ let tokens = tokenizer.encode_batch(input, true).map_err(E::msg)?;
+
+ let token_ids = tokens
+ .iter()
+ .map(|tokens| {
+ let tokens = tokens.get_ids().to_vec();
+ Tensor::new(tokens.as_slice(), device)
+ })
+ .collect::<candle::Result<Vec<_>>>()?;
+
+ Ok(Tensor::stack(&token_ids, 0)?)
+}
+
+pub fn get_attention_mask(
+ tokenizer: &Tokenizer,
+ input: Vec<&str>,
+ device: &Device,
+) -> anyhow::Result<Tensor> {
+ let tokens = tokenizer.encode_batch(input, true).map_err(E::msg)?;
+
+ let attention_mask = tokens
+ .iter()
+ .map(|tokens| {
+ let tokens = tokens.get_attention_mask().to_vec();
+ Tensor::new(tokens.as_slice(), device)
+ })
+ .collect::<candle::Result<Vec<_>>>()?;
+ Ok(Tensor::stack(&attention_mask, 0)?)
+}