summaryrefslogtreecommitdiff
path: root/candle-transformers/src/models/quantized_t5.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-transformers/src/models/quantized_t5.rs')
-rw-r--r--candle-transformers/src/models/quantized_t5.rs4
1 files changed, 2 insertions, 2 deletions
diff --git a/candle-transformers/src/models/quantized_t5.rs b/candle-transformers/src/models/quantized_t5.rs
index 03f5ef0f..8d03ec44 100644
--- a/candle-transformers/src/models/quantized_t5.rs
+++ b/candle-transformers/src/models/quantized_t5.rs
@@ -644,7 +644,7 @@ pub struct T5EncoderModel {
impl T5EncoderModel {
pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
- let shared_vb = if vb.contains_key("shared") {
+ let shared_vb = if vb.contains_key("shared.weight") {
vb.pp("shared")
} else {
vb.pp("decoder").pp("embed_tokens")
@@ -690,7 +690,7 @@ impl T5ForConditionalGeneration {
pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
assert!(cfg.is_encoder_decoder);
let d_model = cfg.d_model;
- let shared_vb = if vb.contains_key("shared") {
+ let shared_vb = if vb.contains_key("shared.weight") {
vb.pp("shared")
} else {
vb.pp("decoder").pp("embed_tokens")