diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-03-28 17:58:06 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-03-28 17:58:06 +0100 |
commit | c5092f2c2977dbb0b45d16a869d22f4c2790a1e2 (patch) | |
tree | bee051cc2dc977a1b39d5bbecda23f65b56eaac6 /candle-examples/examples/t5 | |
parent | cdc8b57b5cf28ad92642b076d67e610bdb958b2d (diff) | |
download | candle-c5092f2c2977dbb0b45d16a869d22f4c2790a1e2.tar.gz candle-c5092f2c2977dbb0b45d16a869d22f4c2790a1e2.tar.bz2 candle-c5092f2c2977dbb0b45d16a869d22f4c2790a1e2.zip |
Add a couple t5 models. (#1958)
Diffstat (limited to 'candle-examples/examples/t5')
-rw-r--r-- | candle-examples/examples/t5/main.rs | 22 |
1 files changed, 19 insertions, 3 deletions
diff --git a/candle-examples/examples/t5/main.rs b/candle-examples/examples/t5/main.rs index 8ef108b6..be6bc6b5 100644 --- a/candle-examples/examples/t5/main.rs +++ b/candle-examples/examples/t5/main.rs @@ -12,12 +12,19 @@ use anyhow::{Error as E, Result}; use candle::{DType, Device, Tensor}; use candle_nn::VarBuilder; use candle_transformers::generation::LogitsProcessor; -use clap::Parser; +use clap::{Parser, ValueEnum}; use hf_hub::{api::sync::Api, Repo, RepoType}; use tokenizers::Tokenizer; const DTYPE: DType = DType::F32; +#[derive(Clone, Debug, Copy, ValueEnum)] +enum Which { + T5Base, + T5Small, + T5_3B, +} + #[derive(Parser, Debug, Clone)] #[command(author, version, about, long_about = None)] struct Args { @@ -71,6 +78,10 @@ struct Args { /// The context size to consider for the repeat penalty. #[arg(long, default_value_t = 64)] repeat_last_n: usize, + + /// The model to be used. + #[arg(long, default_value = "t5-small")] + which: Which, } struct T5ModelBuilder { @@ -82,8 +93,13 @@ struct T5ModelBuilder { impl T5ModelBuilder { pub fn load(args: &Args) -> Result<(Self, Tokenizer)> { let device = candle_examples::device(args.cpu)?; - let default_model = "t5-small".to_string(); - let default_revision = "refs/pr/15".to_string(); + let (default_model, default_revision) = match args.which { + Which::T5Base => ("t5-base", "main"), + Which::T5Small => ("t5-small", "refs/pr/15"), + Which::T5_3B => ("t5-3b", "main"), + }; + let default_model = default_model.to_string(); + let default_revision = default_revision.to_string(); let (model_id, revision) = match (args.model_id.to_owned(), args.revision.to_owned()) { (Some(model_id), Some(revision)) => (model_id, revision), (Some(model_id), None) => (model_id, "main".to_string()), |