summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-examples/examples/quantized-t5/main.rs6
-rw-r--r--candle-examples/examples/t5/main.rs7
-rw-r--r--candle-transformers/src/models/quantized_t5.rs16
-rw-r--r--candle-transformers/src/models/t5.rs17
-rw-r--r--candle-transformers/src/quantized_var_builder.rs4
5 files changed, 44 insertions, 6 deletions
diff --git a/candle-examples/examples/quantized-t5/main.rs b/candle-examples/examples/quantized-t5/main.rs
index 5a1cdf0c..0ea2e0bd 100644
--- a/candle-examples/examples/quantized-t5/main.rs
+++ b/candle-examples/examples/quantized-t5/main.rs
@@ -173,7 +173,11 @@ fn main() -> Result<()> {
.to_vec();
let input_token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
let mut model = builder.build_model()?;
- let mut output_token_ids = [builder.config.pad_token_id as u32].to_vec();
+ let mut output_token_ids = [builder
+ .config
+ .decoder_start_token_id
+ .unwrap_or(builder.config.pad_token_id) as u32]
+ .to_vec();
let temperature = if args.temperature <= 0. {
None
} else {
diff --git a/candle-examples/examples/t5/main.rs b/candle-examples/examples/t5/main.rs
index fe59d578..f1c5a94b 100644
--- a/candle-examples/examples/t5/main.rs
+++ b/candle-examples/examples/t5/main.rs
@@ -172,7 +172,12 @@ fn main() -> Result<()> {
println!("Took {:?}", start.elapsed());
} else {
let mut model = builder.build_conditional_generation()?;
- let mut output_token_ids = [builder.config.pad_token_id as u32].to_vec();
+ let mut output_token_ids = [builder
+ .config
+ .decoder_start_token_id
+ .unwrap_or(builder.config.pad_token_id)
+ as u32]
+ .to_vec();
if let Some(decoder_prompt) = &args.decoder_prompt {
print!("{decoder_prompt}");
output_token_ids.extend(
diff --git a/candle-transformers/src/models/quantized_t5.rs b/candle-transformers/src/models/quantized_t5.rs
index 4e5bd81a..03f5ef0f 100644
--- a/candle-transformers/src/models/quantized_t5.rs
+++ b/candle-transformers/src/models/quantized_t5.rs
@@ -65,6 +65,7 @@ pub struct Config {
pub use_cache: bool,
pub pad_token_id: usize,
pub eos_token_id: usize,
+ pub decoder_start_token_id: Option<usize>,
}
impl Default for Config {
@@ -89,6 +90,7 @@ impl Default for Config {
use_cache: true,
pad_token_id: 0,
eos_token_id: 1,
+ decoder_start_token_id: Some(0),
}
}
}
@@ -642,7 +644,12 @@ pub struct T5EncoderModel {
impl T5EncoderModel {
pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
- let shared = Embedding::new(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?;
+ let shared_vb = if vb.contains_key("shared") {
+ vb.pp("shared")
+ } else {
+ vb.pp("decoder").pp("embed_tokens")
+ };
+ let shared = Embedding::new(cfg.vocab_size, cfg.d_model, shared_vb)?;
let shared = Arc::new(shared);
let encoder = T5Stack::load(false, vb.pp("encoder"), &shared, cfg)?;
Ok(Self {
@@ -683,7 +690,12 @@ impl T5ForConditionalGeneration {
pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
assert!(cfg.is_encoder_decoder);
let d_model = cfg.d_model;
- let shared = Embedding::new(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?;
+ let shared_vb = if vb.contains_key("shared") {
+ vb.pp("shared")
+ } else {
+ vb.pp("decoder").pp("embed_tokens")
+ };
+ let shared = Embedding::new(cfg.vocab_size, cfg.d_model, shared_vb)?;
let shared = Arc::new(shared);
let mut encoder_cfg = cfg.clone();
diff --git a/candle-transformers/src/models/t5.rs b/candle-transformers/src/models/t5.rs
index 1101d001..3069be1c 100644
--- a/candle-transformers/src/models/t5.rs
+++ b/candle-transformers/src/models/t5.rs
@@ -63,6 +63,7 @@ pub struct Config {
pub use_cache: bool,
pub pad_token_id: usize,
pub eos_token_id: usize,
+ pub decoder_start_token_id: Option<usize>,
}
impl Default for Config {
@@ -87,6 +88,7 @@ impl Default for Config {
use_cache: true,
pad_token_id: 0,
eos_token_id: 1,
+ decoder_start_token_id: Some(0),
}
}
}
@@ -110,6 +112,7 @@ impl Config {
num_heads: 12,
num_layers: 12,
pad_token_id: 0,
+ decoder_start_token_id: Some(0),
relative_attention_max_distance: 128,
relative_attention_num_buckets: 32,
use_cache: true,
@@ -667,7 +670,12 @@ pub struct T5EncoderModel {
impl T5EncoderModel {
pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
- let shared = Embedding::new(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?;
+ let shared_vb = if vb.contains_tensor("shared") {
+ vb.pp("shared")
+ } else {
+ vb.pp("decoder").pp("embed_tokens")
+ };
+ let shared = Embedding::new(cfg.vocab_size, cfg.d_model, shared_vb)?;
let shared = Arc::new(shared);
let encoder = T5Stack::load(false, vb.pp("encoder"), &shared, cfg)?;
Ok(Self {
@@ -708,7 +716,12 @@ impl T5ForConditionalGeneration {
pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
assert!(cfg.is_encoder_decoder);
let d_model = cfg.d_model;
- let shared = Embedding::new(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?;
+ let shared_vb = if vb.contains_tensor("shared") {
+ vb.pp("shared")
+ } else {
+ vb.pp("decoder").pp("embed_tokens")
+ };
+ let shared = Embedding::new(cfg.vocab_size, cfg.d_model, shared_vb)?;
let shared = Arc::new(shared);
let mut encoder_cfg = cfg.clone();
diff --git a/candle-transformers/src/quantized_var_builder.rs b/candle-transformers/src/quantized_var_builder.rs
index 810802e8..63101f4c 100644
--- a/candle-transformers/src/quantized_var_builder.rs
+++ b/candle-transformers/src/quantized_var_builder.rs
@@ -90,4 +90,8 @@ impl VarBuilder {
pub fn device(&self) -> &Device {
&self.device
}
+
+ pub fn contains_key(&self, key: &str) -> bool {
+ self.data.contains_key(key)
+ }
}