summaryrefslogtreecommitdiff
path: root/candle-transformers/src/models/quantized_t5.rs
diff options
context:
space:
mode:
authorJuarez Bochi <juarez.bochi@grammarly.com>2023-11-06 23:35:37 -0500
committerGitHub <noreply@github.com>2023-11-07 05:35:37 +0100
commit508f811b93035f076e18778fe08106f15abfa8a7 (patch)
treeb31cf3c6bbaf335ab8371f71a1353fe597bab8fd /candle-transformers/src/models/quantized_t5.rs
parenta773a4b22b88d9955f51de552d72717441d49729 (diff)
downloadcandle-508f811b93035f076e18778fe08106f15abfa8a7.tar.gz
candle-508f811b93035f076e18778fe08106f15abfa8a7.tar.bz2
candle-508f811b93035f076e18778fe08106f15abfa8a7.zip
Add support for MADLAD400 (#1285)
* Add support for madlad * Add support for quantized MADLAD
Diffstat (limited to 'candle-transformers/src/models/quantized_t5.rs')
-rw-r--r--candle-transformers/src/models/quantized_t5.rs16
1 files changed, 14 insertions, 2 deletions
diff --git a/candle-transformers/src/models/quantized_t5.rs b/candle-transformers/src/models/quantized_t5.rs
index 4e5bd81a..03f5ef0f 100644
--- a/candle-transformers/src/models/quantized_t5.rs
+++ b/candle-transformers/src/models/quantized_t5.rs
@@ -65,6 +65,7 @@ pub struct Config {
pub use_cache: bool,
pub pad_token_id: usize,
pub eos_token_id: usize,
+ pub decoder_start_token_id: Option<usize>,
}
impl Default for Config {
@@ -89,6 +90,7 @@ impl Default for Config {
use_cache: true,
pad_token_id: 0,
eos_token_id: 1,
+ decoder_start_token_id: Some(0),
}
}
}
@@ -642,7 +644,12 @@ pub struct T5EncoderModel {
impl T5EncoderModel {
pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
- let shared = Embedding::new(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?;
+ let shared_vb = if vb.contains_key("shared") {
+ vb.pp("shared")
+ } else {
+ vb.pp("decoder").pp("embed_tokens")
+ };
+ let shared = Embedding::new(cfg.vocab_size, cfg.d_model, shared_vb)?;
let shared = Arc::new(shared);
let encoder = T5Stack::load(false, vb.pp("encoder"), &shared, cfg)?;
Ok(Self {
@@ -683,7 +690,12 @@ impl T5ForConditionalGeneration {
pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
assert!(cfg.is_encoder_decoder);
let d_model = cfg.d_model;
- let shared = Embedding::new(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?;
+ let shared_vb = if vb.contains_key("shared") {
+ vb.pp("shared")
+ } else {
+ vb.pp("decoder").pp("embed_tokens")
+ };
+ let shared = Embedding::new(cfg.vocab_size, cfg.d_model, shared_vb)?;
let shared = Arc::new(shared);
let mut encoder_cfg = cfg.clone();