diff options
-rw-r--r-- | candle-examples/examples/quantized-t5/main.rs | 6 | ||||
-rw-r--r-- | candle-examples/examples/t5/main.rs | 7 | ||||
-rw-r--r-- | candle-transformers/src/models/quantized_t5.rs | 16 | ||||
-rw-r--r-- | candle-transformers/src/models/t5.rs | 17 | ||||
-rw-r--r-- | candle-transformers/src/quantized_var_builder.rs | 4 |
5 files changed, 44 insertions, 6 deletions
diff --git a/candle-examples/examples/quantized-t5/main.rs b/candle-examples/examples/quantized-t5/main.rs index 5a1cdf0c..0ea2e0bd 100644 --- a/candle-examples/examples/quantized-t5/main.rs +++ b/candle-examples/examples/quantized-t5/main.rs @@ -173,7 +173,11 @@ fn main() -> Result<()> { .to_vec(); let input_token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?; let mut model = builder.build_model()?; - let mut output_token_ids = [builder.config.pad_token_id as u32].to_vec(); + let mut output_token_ids = [builder + .config + .decoder_start_token_id + .unwrap_or(builder.config.pad_token_id) as u32] + .to_vec(); let temperature = if args.temperature <= 0. { None } else { diff --git a/candle-examples/examples/t5/main.rs b/candle-examples/examples/t5/main.rs index fe59d578..f1c5a94b 100644 --- a/candle-examples/examples/t5/main.rs +++ b/candle-examples/examples/t5/main.rs @@ -172,7 +172,12 @@ fn main() -> Result<()> { println!("Took {:?}", start.elapsed()); } else { let mut model = builder.build_conditional_generation()?; - let mut output_token_ids = [builder.config.pad_token_id as u32].to_vec(); + let mut output_token_ids = [builder + .config + .decoder_start_token_id + .unwrap_or(builder.config.pad_token_id) + as u32] + .to_vec(); if let Some(decoder_prompt) = &args.decoder_prompt { print!("{decoder_prompt}"); output_token_ids.extend( diff --git a/candle-transformers/src/models/quantized_t5.rs b/candle-transformers/src/models/quantized_t5.rs index 4e5bd81a..03f5ef0f 100644 --- a/candle-transformers/src/models/quantized_t5.rs +++ b/candle-transformers/src/models/quantized_t5.rs @@ -65,6 +65,7 @@ pub struct Config { pub use_cache: bool, pub pad_token_id: usize, pub eos_token_id: usize, + pub decoder_start_token_id: Option<usize>, } impl Default for Config { @@ -89,6 +90,7 @@ impl Default for Config { use_cache: true, pad_token_id: 0, eos_token_id: 1, + decoder_start_token_id: Some(0), } } } @@ -642,7 +644,12 @@ pub struct T5EncoderModel { impl T5EncoderModel { pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> { - let shared = Embedding::new(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?; + let shared_vb = if vb.contains_key("shared") { + vb.pp("shared") + } else { + vb.pp("decoder").pp("embed_tokens") + }; + let shared = Embedding::new(cfg.vocab_size, cfg.d_model, shared_vb)?; let shared = Arc::new(shared); let encoder = T5Stack::load(false, vb.pp("encoder"), &shared, cfg)?; Ok(Self { @@ -683,7 +690,12 @@ impl T5ForConditionalGeneration { pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> { assert!(cfg.is_encoder_decoder); let d_model = cfg.d_model; - let shared = Embedding::new(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?; + let shared_vb = if vb.contains_key("shared") { + vb.pp("shared") + } else { + vb.pp("decoder").pp("embed_tokens") + }; + let shared = Embedding::new(cfg.vocab_size, cfg.d_model, shared_vb)?; let shared = Arc::new(shared); let mut encoder_cfg = cfg.clone(); diff --git a/candle-transformers/src/models/t5.rs b/candle-transformers/src/models/t5.rs index 1101d001..3069be1c 100644 --- a/candle-transformers/src/models/t5.rs +++ b/candle-transformers/src/models/t5.rs @@ -63,6 +63,7 @@ pub struct Config { pub use_cache: bool, pub pad_token_id: usize, pub eos_token_id: usize, + pub decoder_start_token_id: Option<usize>, } impl Default for Config { @@ -87,6 +88,7 @@ impl Default for Config { use_cache: true, pad_token_id: 0, eos_token_id: 1, + decoder_start_token_id: Some(0), } } } @@ -110,6 +112,7 @@ impl Config { num_heads: 12, num_layers: 12, pad_token_id: 0, + decoder_start_token_id: Some(0), relative_attention_max_distance: 128, relative_attention_num_buckets: 32, use_cache: true, @@ -667,7 +670,12 @@ pub struct T5EncoderModel { impl T5EncoderModel { pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> { - let shared = Embedding::new(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?; + let shared_vb = if vb.contains_tensor("shared") { + vb.pp("shared") + } else { + vb.pp("decoder").pp("embed_tokens") + }; + let shared = Embedding::new(cfg.vocab_size, cfg.d_model, shared_vb)?; let shared = Arc::new(shared); let encoder = T5Stack::load(false, vb.pp("encoder"), &shared, cfg)?; Ok(Self { @@ -708,7 +716,12 @@ impl T5ForConditionalGeneration { pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> { assert!(cfg.is_encoder_decoder); let d_model = cfg.d_model; - let shared = Embedding::new(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?; + let shared_vb = if vb.contains_tensor("shared") { + vb.pp("shared") + } else { + vb.pp("decoder").pp("embed_tokens") + }; + let shared = Embedding::new(cfg.vocab_size, cfg.d_model, shared_vb)?; let shared = Arc::new(shared); let mut encoder_cfg = cfg.clone(); diff --git a/candle-transformers/src/quantized_var_builder.rs b/candle-transformers/src/quantized_var_builder.rs index 810802e8..63101f4c 100644 --- a/candle-transformers/src/quantized_var_builder.rs +++ b/candle-transformers/src/quantized_var_builder.rs @@ -90,4 +90,8 @@ impl VarBuilder { pub fn device(&self) -> &Device { &self.device } + + pub fn contains_key(&self, key: &str) -> bool { + self.data.contains_key(key) + } } |