summaryrefslogtreecommitdiff
path: root/candle-examples/examples/quantized-t5/main.rs
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-09-21 13:23:30 +0100
committerGitHub <noreply@github.com>2023-09-21 13:23:30 +0100
commitb43ca493f67a98aa6a6f53144ecb17a0a0d20fd0 (patch)
tree893f80ec5aaf0a9c90d823870bb964d430843696 /candle-examples/examples/quantized-t5/main.rs
parent3b557765e8e1641d1289d33b177938abe10d24d2 (diff)
downloadcandle-b43ca493f67a98aa6a6f53144ecb17a0a0d20fd0.tar.gz
candle-b43ca493f67a98aa6a6f53144ecb17a0a0d20fd0.tar.bz2
candle-b43ca493f67a98aa6a6f53144ecb17a0a0d20fd0.zip
Add more quantized flan t5 variants (#923)
* Add the quantized flan-t5-large variant. * Add more sizes.
Diffstat (limited to 'candle-examples/examples/quantized-t5/main.rs')
-rw-r--r--candle-examples/examples/quantized-t5/main.rs34
1 files changed, 31 insertions, 3 deletions
diff --git a/candle-examples/examples/quantized-t5/main.rs b/candle-examples/examples/quantized-t5/main.rs
index 86d3762e..93a86309 100644
--- a/candle-examples/examples/quantized-t5/main.rs
+++ b/candle-examples/examples/quantized-t5/main.rs
@@ -11,10 +11,20 @@ use candle_transformers::models::quantized_t5 as t5;
use anyhow::{Error as E, Result};
use candle::{Device, Tensor};
use candle_transformers::generation::LogitsProcessor;
-use clap::Parser;
+use clap::{Parser, ValueEnum};
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;
+#[derive(Clone, Debug, Copy, ValueEnum)]
+enum Which {
+ T5Small,
+ FlanT5Small,
+ FlanT5Base,
+ FlanT5Large,
+ FlanT5Xl,
+ FlanT5Xxl,
+}
+
#[derive(Parser, Debug, Clone)]
#[command(author, version, about, long_about = None)]
struct Args {
@@ -55,6 +65,10 @@ struct Args {
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,
+
+ /// The model size to use.
+ #[arg(long, default_value = "t5-small")]
+ which: Which,
}
struct T5ModelBuilder {
@@ -77,11 +91,25 @@ impl T5ModelBuilder {
let repo = Repo::with_revision(model_id, RepoType::Model, revision);
let api = Api::new()?;
let api = api.repo(repo);
- let config_filename = api.get("config.json")?;
+ let config_filename = match args.which {
+ Which::T5Small => api.get("config.json")?,
+ Which::FlanT5Small => api.get("config-flan-t5-small.json")?,
+ Which::FlanT5Base => api.get("config-flan-t5-base.json")?,
+ Which::FlanT5Large => api.get("config-flan-t5-large.json")?,
+ Which::FlanT5Xl => api.get("config-flan-t5-xl.json")?,
+ Which::FlanT5Xxl => api.get("config-flan-t5-xxl.json")?,
+ };
let tokenizer_filename = api.get("tokenizer.json")?;
let weights_filename = match &args.weight_file {
Some(filename) => std::path::PathBuf::from(filename),
- None => api.get("model.gguf")?,
+ None => match args.which {
+ Which::T5Small => api.get("model.gguf")?,
+ Which::FlanT5Small => api.get("model-flan-t5-small.gguf")?,
+ Which::FlanT5Base => api.get("model-flan-t5-base.gguf")?,
+ Which::FlanT5Large => api.get("model-flan-t5-large.gguf")?,
+ Which::FlanT5Xl => api.get("model-flan-t5-xl.gguf")?,
+ Which::FlanT5Xxl => api.get("model-flan-t5-xxl.gguf")?,
+ },
};
let config = std::fs::read_to_string(config_filename)?;
let mut config: t5::Config = serde_json::from_str(&config)?;