diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-04-01 20:58:34 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-04-01 20:58:34 +0200 |
commit | be9c200cbb16b59fe1f1e8c0f606981412c9b757 (patch) | |
tree | 4a2cdbb5d388f4cfaa42b53a9297575174baa2d3 | |
parent | ea0d8d3753b53a936c472c30ae5dc0d52bfa81fa (diff) | |
download | candle-be9c200cbb16b59fe1f1e8c0f606981412c9b757.tar.gz candle-be9c200cbb16b59fe1f1e8c0f606981412c9b757.tar.bz2 candle-be9c200cbb16b59fe1f1e8c0f606981412c9b757.zip |
Expose the t5 config fields + allow t5-large. (#1987)
-rw-r--r-- | candle-examples/examples/t5/main.rs | 2 | ||||
-rw-r--r-- | candle-transformers/src/models/t5.rs | 32 |
2 files changed, 18 insertions, 16 deletions
diff --git a/candle-examples/examples/t5/main.rs b/candle-examples/examples/t5/main.rs index 34ae0ead..902282c1 100644 --- a/candle-examples/examples/t5/main.rs +++ b/candle-examples/examples/t5/main.rs @@ -22,6 +22,7 @@ const DTYPE: DType = DType::F32; enum Which { T5Base, T5Small, + T5Large, T5_3B, Mt5Base, Mt5Small, @@ -108,6 +109,7 @@ impl T5ModelBuilder { let (default_model, default_revision) = match args.which { Which::T5Base => ("t5-base", "main"), Which::T5Small => ("t5-small", "refs/pr/15"), + Which::T5Large => ("t5-large", "main"), Which::T5_3B => ("t5-3b", "main"), Which::Mt5Base => ("google/mt5-base", "refs/pr/5"), Which::Mt5Small => ("google/mt5-small", "refs/pr/6"), diff --git a/candle-transformers/src/models/t5.rs b/candle-transformers/src/models/t5.rs index 5dc44cb5..f4b5b4b0 100644 --- a/candle-transformers/src/models/t5.rs +++ b/candle-transformers/src/models/t5.rs @@ -70,26 +70,26 @@ where #[derive(Debug, Clone, PartialEq, Deserialize)] pub struct Config { - vocab_size: usize, - d_model: usize, - d_kv: usize, - d_ff: usize, - num_layers: usize, - num_decoder_layers: Option<usize>, - num_heads: usize, - relative_attention_num_buckets: usize, + pub vocab_size: usize, + pub d_model: usize, + pub d_kv: usize, + pub d_ff: usize, + pub num_layers: usize, + pub num_decoder_layers: Option<usize>, + pub num_heads: usize, + pub relative_attention_num_buckets: usize, #[serde(default = "default_relative_attention_max_distance")] - relative_attention_max_distance: usize, - dropout_rate: f64, - layer_norm_epsilon: f64, - initializer_factor: f64, + pub relative_attention_max_distance: usize, + pub dropout_rate: f64, + pub layer_norm_epsilon: f64, + pub initializer_factor: f64, #[serde(default, deserialize_with = "deserialize_feed_forward_proj_activation")] - feed_forward_proj: ActivationWithOptionalGating, + pub feed_forward_proj: ActivationWithOptionalGating, #[serde(default = "default_tie_word_embeddings")] - tie_word_embeddings: bool, + pub tie_word_embeddings: bool, #[serde(default = "default_is_decoder")] - is_decoder: bool, - is_encoder_decoder: bool, + pub is_decoder: bool, + pub is_encoder_decoder: bool, #[serde(default = "default_use_cache")] pub use_cache: bool, pub pad_token_id: usize, |