summary refs log tree commit diff
path: root/candle-examples/examples/t5
diff options
context:
space:
mode:
author Laurent Mazare <laurent.mazare@gmail.com> 2024-03-28 17:58:06 +0100
committer GitHub <noreply@github.com> 2024-03-28 17:58:06 +0100
commit c5092f2c2977dbb0b45d16a869d22f4c2790a1e2 (patch)
tree bee051cc2dc977a1b39d5bbecda23f65b56eaac6 /candle-examples/examples/t5
parent cdc8b57b5cf28ad92642b076d67e610bdb958b2d (diff)
download candle-c5092f2c2977dbb0b45d16a869d22f4c2790a1e2.tar.gz
candle-c5092f2c2977dbb0b45d16a869d22f4c2790a1e2.tar.bz2
candle-c5092f2c2977dbb0b45d16a869d22f4c2790a1e2.zip
Add a couple t5 models. (#1958)
Diffstat (limited to 'candle-examples/examples/t5')
-rw-r--r-- candle-examples/examples/t5/main.rs 22
1 files changed, 19 insertions, 3 deletions
diff --git a/candle-examples/examples/t5/main.rs b/candle-examples/examples/t5/main.rs
index 8ef108b6..be6bc6b5 100644
--- a/candle-examples/examples/t5/main.rs
+++ b/candle-examples/examples/t5/main.rs
@@ -12,12 +12,19 @@ use anyhow::{Error as E, Result};
use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
-use clap::Parser;
+use clap::{Parser, ValueEnum};
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;
const DTYPE: DType = DType::F32;
+#[derive(Clone, Debug, Copy, ValueEnum)]
+enum Which {
+ T5Base,
+ T5Small,
+ T5_3B,
+}
+
#[derive(Parser, Debug, Clone)]
#[command(author, version, about, long_about = None)]
struct Args {
@@ -71,6 +78,10 @@ struct Args {
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,
+
+ /// The model to be used.
+ #[arg(long, default_value = "t5-small")]
+ which: Which,
}
struct T5ModelBuilder {
@@ -82,8 +93,13 @@ struct T5ModelBuilder {
impl T5ModelBuilder {
pub fn load(args: &Args) -> Result<(Self, Tokenizer)> {
let device = candle_examples::device(args.cpu)?;
- let default_model = "t5-small".to_string();
- let default_revision = "refs/pr/15".to_string();
+ let (default_model, default_revision) = match args.which {
+ Which::T5Base => ("t5-base", "main"),
+ Which::T5Small => ("t5-small", "refs/pr/15"),
+ Which::T5_3B => ("t5-3b", "main"),
+ };
+ let default_model = default_model.to_string();
+ let default_revision = default_revision.to_string();
let (model_id, revision) = match (args.model_id.to_owned(), args.revision.to_owned()) {
(Some(model_id), Some(revision)) => (model_id, revision),
(Some(model_id), None) => (model_id, "main".to_string()),