summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2024-04-01 20:58:34 +0200
committerGitHub <noreply@github.com>2024-04-01 20:58:34 +0200
commitbe9c200cbb16b59fe1f1e8c0f606981412c9b757 (patch)
tree4a2cdbb5d388f4cfaa42b53a9297575174baa2d3
parentea0d8d3753b53a936c472c30ae5dc0d52bfa81fa (diff)
downloadcandle-be9c200cbb16b59fe1f1e8c0f606981412c9b757.tar.gz
candle-be9c200cbb16b59fe1f1e8c0f606981412c9b757.tar.bz2
candle-be9c200cbb16b59fe1f1e8c0f606981412c9b757.zip
Expose the t5 config fields + allow t5-large. (#1987)
-rw-r--r--candle-examples/examples/t5/main.rs2
-rw-r--r--candle-transformers/src/models/t5.rs32
2 files changed, 18 insertions, 16 deletions
diff --git a/candle-examples/examples/t5/main.rs b/candle-examples/examples/t5/main.rs
index 34ae0ead..902282c1 100644
--- a/candle-examples/examples/t5/main.rs
+++ b/candle-examples/examples/t5/main.rs
@@ -22,6 +22,7 @@ const DTYPE: DType = DType::F32;
enum Which {
T5Base,
T5Small,
+ T5Large,
T5_3B,
Mt5Base,
Mt5Small,
@@ -108,6 +109,7 @@ impl T5ModelBuilder {
let (default_model, default_revision) = match args.which {
Which::T5Base => ("t5-base", "main"),
Which::T5Small => ("t5-small", "refs/pr/15"),
+ Which::T5Large => ("t5-large", "main"),
Which::T5_3B => ("t5-3b", "main"),
Which::Mt5Base => ("google/mt5-base", "refs/pr/5"),
Which::Mt5Small => ("google/mt5-small", "refs/pr/6"),
diff --git a/candle-transformers/src/models/t5.rs b/candle-transformers/src/models/t5.rs
index 5dc44cb5..f4b5b4b0 100644
--- a/candle-transformers/src/models/t5.rs
+++ b/candle-transformers/src/models/t5.rs
@@ -70,26 +70,26 @@ where
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct Config {
- vocab_size: usize,
- d_model: usize,
- d_kv: usize,
- d_ff: usize,
- num_layers: usize,
- num_decoder_layers: Option<usize>,
- num_heads: usize,
- relative_attention_num_buckets: usize,
+ pub vocab_size: usize,
+ pub d_model: usize,
+ pub d_kv: usize,
+ pub d_ff: usize,
+ pub num_layers: usize,
+ pub num_decoder_layers: Option<usize>,
+ pub num_heads: usize,
+ pub relative_attention_num_buckets: usize,
#[serde(default = "default_relative_attention_max_distance")]
- relative_attention_max_distance: usize,
- dropout_rate: f64,
- layer_norm_epsilon: f64,
- initializer_factor: f64,
+ pub relative_attention_max_distance: usize,
+ pub dropout_rate: f64,
+ pub layer_norm_epsilon: f64,
+ pub initializer_factor: f64,
#[serde(default, deserialize_with = "deserialize_feed_forward_proj_activation")]
- feed_forward_proj: ActivationWithOptionalGating,
+ pub feed_forward_proj: ActivationWithOptionalGating,
#[serde(default = "default_tie_word_embeddings")]
- tie_word_embeddings: bool,
+ pub tie_word_embeddings: bool,
#[serde(default = "default_is_decoder")]
- is_decoder: bool,
- is_encoder_decoder: bool,
+ pub is_decoder: bool,
+ pub is_encoder_decoder: bool,
#[serde(default = "default_use_cache")]
pub use_cache: bool,
pub pad_token_id: usize,