diff options
Diffstat (limited to 'candle-examples/examples/t5/main.rs')
-rw-r--r-- | candle-examples/examples/t5/main.rs | 22 |
1 files changed, 19 insertions, 3 deletions
diff --git a/candle-examples/examples/t5/main.rs b/candle-examples/examples/t5/main.rs index 8ef108b6..be6bc6b5 100644 --- a/candle-examples/examples/t5/main.rs +++ b/candle-examples/examples/t5/main.rs @@ -12,12 +12,19 @@ use anyhow::{Error as E, Result}; use candle::{DType, Device, Tensor}; use candle_nn::VarBuilder; use candle_transformers::generation::LogitsProcessor; -use clap::Parser; +use clap::{Parser, ValueEnum}; use hf_hub::{api::sync::Api, Repo, RepoType}; use tokenizers::Tokenizer; const DTYPE: DType = DType::F32; +#[derive(Clone, Debug, Copy, ValueEnum)] +enum Which { + T5Base, + T5Small, + T5_3B, +} + #[derive(Parser, Debug, Clone)] #[command(author, version, about, long_about = None)] struct Args { @@ -71,6 +78,10 @@ struct Args { /// The context size to consider for the repeat penalty. #[arg(long, default_value_t = 64)] repeat_last_n: usize, + + /// The model to be used. + #[arg(long, default_value = "t5-small")] + which: Which, } struct T5ModelBuilder { @@ -82,8 +93,13 @@ struct T5ModelBuilder { impl T5ModelBuilder { pub fn load(args: &Args) -> Result<(Self, Tokenizer)> { let device = candle_examples::device(args.cpu)?; - let default_model = "t5-small".to_string(); - let default_revision = "refs/pr/15".to_string(); + let (default_model, default_revision) = match args.which { + Which::T5Base => ("t5-base", "main"), + Which::T5Small => ("t5-small", "refs/pr/15"), + Which::T5_3B => ("t5-3b", "main"), + }; + let default_model = default_model.to_string(); + let default_revision = default_revision.to_string(); let (model_id, revision) = match (args.model_id.to_owned(), args.revision.to_owned()) { (Some(model_id), Some(revision)) => (model_id, revision), (Some(model_id), None) => (model_id, "main".to_string()), |