diff options
author | Juarez Bochi <juarez.bochi@grammarly.com> | 2023-11-08 11:55:46 -0500 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-11-08 17:55:46 +0100 |
commit | f772213e844fdfcc8dbaf662fc11819f4028dc78 (patch) | |
tree | 28d01a38a9f3bcb6ad5841cb6bf4ab81f54db505 /candle-transformers/src/models/quantized_t5.rs | |
parent | 2feb0b054f96e4c4c87f01b243e749896c94f8c7 (diff) | |
download | candle-f772213e844fdfcc8dbaf662fc11819f4028dc78.tar.gz candle-f772213e844fdfcc8dbaf662fc11819f4028dc78.tar.bz2 candle-f772213e844fdfcc8dbaf662fc11819f4028dc78.zip |
Fix bug introduced in madlad PR (#1298)
Diffstat (limited to 'candle-transformers/src/models/quantized_t5.rs')
-rw-r--r-- | candle-transformers/src/models/quantized_t5.rs | 4 |
1 files changed, 2 insertions, 2 deletions
diff --git a/candle-transformers/src/models/quantized_t5.rs b/candle-transformers/src/models/quantized_t5.rs index 03f5ef0f..8d03ec44 100644 --- a/candle-transformers/src/models/quantized_t5.rs +++ b/candle-transformers/src/models/quantized_t5.rs @@ -644,7 +644,7 @@ pub struct T5EncoderModel { impl T5EncoderModel { pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> { - let shared_vb = if vb.contains_key("shared") { + let shared_vb = if vb.contains_key("shared.weight") { vb.pp("shared") } else { vb.pp("decoder").pp("embed_tokens") @@ -690,7 +690,7 @@ impl T5ForConditionalGeneration { pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> { assert!(cfg.is_encoder_decoder); let d_model = cfg.d_model; - let shared_vb = if vb.contains_key("shared") { + let shared_vb = if vb.contains_key("shared.weight") { vb.pp("shared") } else { vb.pp("decoder").pp("embed_tokens") |