author    Laurent Mazare <laurent.mazare@gmail.com>    2023-10-30 20:20:36 +0100
committer GitHub <noreply@github.com>    2023-10-30 19:20:36 +0000
commit    392a00a147c26ebe70c6484d72223d02ada6a72a (patch)
tree      653d9d4c29f5738f598fb96329db9031561d7057
parent    4c967b9184834cd1e166dfdd6d88450d16bad8f2 (diff)
Add support for the marian base model. (#1221)
-rw-r--r--  candle-examples/examples/marian-mt/main.rs  | 56
-rw-r--r--  candle-nn/src/activation.rs                 |  2
-rw-r--r--  candle-transformers/src/models/marian.rs    | 25
3 files changed, 72 insertions(+), 11 deletions(-)
diff --git a/candle-examples/examples/marian-mt/main.rs b/candle-examples/examples/marian-mt/main.rs
index c198777c..c503667c 100644
--- a/candle-examples/examples/marian-mt/main.rs
+++ b/candle-examples/examples/marian-mt/main.rs
@@ -5,7 +5,7 @@ extern crate intel_mkl_src;
extern crate accelerate_src;
use anyhow::Error as E;
-use clap::Parser;
+use clap::{Parser, ValueEnum};
use candle::{DType, Tensor};
use candle_nn::VarBuilder;
@@ -13,6 +13,12 @@ use candle_transformers::models::marian;
use tokenizers::Tokenizer;
+#[derive(Clone, Debug, Copy, ValueEnum)]
+enum Which {
+ Base,
+ Big,
+}
+
// TODO: Maybe add support for the conditional prompt.
#[derive(Parser)]
struct Args {
@@ -25,6 +31,10 @@ struct Args {
#[arg(long)]
tokenizer_dec: Option<String>,
+ /// Choose the variant of the model to run.
+ #[arg(long, default_value = "big")]
+ which: Which,
+
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
@@ -42,13 +52,22 @@ pub fn main() -> anyhow::Result<()> {
use hf_hub::api::sync::Api;
let args = Args::parse();
- let config = marian::Config::opus_mt_tc_big_fr_en();
+ let config = match args.which {
+ Which::Base => marian::Config::opus_mt_fr_en(),
+ Which::Big => marian::Config::opus_mt_tc_big_fr_en(),
+ };
let tokenizer = {
let tokenizer = match args.tokenizer {
Some(tokenizer) => std::path::PathBuf::from(tokenizer),
- None => Api::new()?
- .model("lmz/candle-marian".to_string())
- .get("tokenizer-marian-fr.json")?,
+ None => {
+ let name = match args.which {
+ Which::Base => "tokenizer-marian-base-fr.json",
+ Which::Big => "tokenizer-marian-fr.json",
+ };
+ Api::new()?
+ .model("lmz/candle-marian".to_string())
+ .get(name)?
+ }
};
Tokenizer::from_file(&tokenizer).map_err(E::msg)?
};
@@ -56,9 +75,15 @@ pub fn main() -> anyhow::Result<()> {
let tokenizer_dec = {
let tokenizer = match args.tokenizer_dec {
Some(tokenizer) => std::path::PathBuf::from(tokenizer),
- None => Api::new()?
- .model("lmz/candle-marian".to_string())
- .get("tokenizer-marian-en.json")?,
+ None => {
+ let name = match args.which {
+ Which::Base => "tokenizer-marian-base-en.json",
+ Which::Big => "tokenizer-marian-en.json",
+ };
+ Api::new()?
+ .model("lmz/candle-marian".to_string())
+ .get(name)?
+ }
};
Tokenizer::from_file(&tokenizer).map_err(E::msg)?
};
@@ -67,9 +92,18 @@ pub fn main() -> anyhow::Result<()> {
let vb = {
let model = match args.model {
Some(model) => std::path::PathBuf::from(model),
- None => Api::new()?
- .model("Helsinki-NLP/opus-mt-tc-big-fr-en".to_string())
- .get("model.safetensors")?,
+ None => match args.which {
+ Which::Base => Api::new()?
+ .repo(hf_hub::Repo::with_revision(
+ "Helsinki-NLP/opus-mt-fr-en".to_string(),
+ hf_hub::RepoType::Model,
+ "refs/pr/4".to_string(),
+ ))
+ .get("model.safetensors")?,
+ Which::Big => Api::new()?
+ .model("Helsinki-NLP/opus-mt-tc-big-fr-en".to_string())
+ .get("model.safetensors")?,
+ },
};
unsafe { VarBuilder::from_mmaped_safetensors(&[&model], DType::F32, &device)? }
};
diff --git a/candle-nn/src/activation.rs b/candle-nn/src/activation.rs
index 52ceba78..79cf9c82 100644
--- a/candle-nn/src/activation.rs
+++ b/candle-nn/src/activation.rs
@@ -13,6 +13,7 @@ pub enum Activation {
Relu6,
Silu,
Sigmoid,
+ Swish,
Elu(f64),
LeakyRelu(f64),
}
@@ -28,6 +29,7 @@ impl super::Module for Activation {
Self::Relu6 => xs.clamp(0f32, 6f32),
Self::Silu => crate::ops::silu(xs),
Self::Sigmoid => crate::ops::sigmoid(xs),
+ Self::Swish => xs * crate::ops::sigmoid(xs)?,
&Self::Elu(alpha) => xs.elu(alpha),
&Self::LeakyRelu(negative_slope) => crate::ops::leaky_relu(xs, negative_slope),
}
diff --git a/candle-transformers/src/models/marian.rs b/candle-transformers/src/models/marian.rs
index 2bcfd2f7..5305d4d8 100644
--- a/candle-transformers/src/models/marian.rs
+++ b/candle-transformers/src/models/marian.rs
@@ -51,6 +51,31 @@ impl Config {
vocab_size: 53017,
}
}
+
+ // https://huggingface.co/Helsinki-NLP/opus-mt-fr-en/blob/main/config.json
+ pub fn opus_mt_fr_en() -> Self {
+ Self {
+ activation_function: candle_nn::Activation::Swish,
+ d_model: 512,
+ decoder_attention_heads: 8,
+ decoder_ffn_dim: 2048,
+ decoder_layers: 6,
+ decoder_start_token_id: 59513,
+ decoder_vocab_size: Some(59514),
+ encoder_attention_heads: 8,
+ encoder_ffn_dim: 2048,
+ encoder_layers: 6,
+ eos_token_id: 0,
+ forced_eos_token_id: 0,
+ is_encoder_decoder: true,
+ max_position_embeddings: 512,
+ pad_token_id: 59513,
+ scale_embedding: true,
+ share_encoder_decoder_embeddings: true,
+ use_cache: true,
+ vocab_size: 59514,
+ }
+ }
}
#[derive(Debug, Clone)]
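Note: for context, a hypothetical sketch of how the new config plugs into model construction, following the loading pattern from the example above; the `load_base_model` helper is my own name, and `MTModel::new` is assumed from candle-transformers' marian module rather than shown in this diff:

    use candle::{DType, Device};
    use candle_nn::VarBuilder;
    use candle_transformers::models::marian;

    fn load_base_model(weights: &std::path::Path) -> anyhow::Result<marian::MTModel> {
        let device = Device::Cpu;
        let config = marian::Config::opus_mt_fr_en();
        // Safety: mmaps the weight file; it must not be modified while in use.
        let vb = unsafe {
            VarBuilder::from_mmaped_safetensors(&[weights], DType::F32, &device)?
        };
        Ok(marian::MTModel::new(&config, vb)?)
    }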