summaryrefslogtreecommitdiff
path: root/candle-examples/examples/t5
diff options
context:
space:
mode:
authorJuarez Bochi <juarez.bochi@grammarly.com>2023-11-06 23:35:37 -0500
committerGitHub <noreply@github.com>2023-11-07 05:35:37 +0100
commit508f811b93035f076e18778fe08106f15abfa8a7 (patch)
treeb31cf3c6bbaf335ab8371f71a1353fe597bab8fd /candle-examples/examples/t5
parenta773a4b22b88d9955f51de552d72717441d49729 (diff)
downloadcandle-508f811b93035f076e18778fe08106f15abfa8a7.tar.gz
candle-508f811b93035f076e18778fe08106f15abfa8a7.tar.bz2
candle-508f811b93035f076e18778fe08106f15abfa8a7.zip
Add support for MADLAD400 (#1285)
* Add support for madlad * Add support for quantized MADLAD
Diffstat (limited to 'candle-examples/examples/t5')
-rw-r--r--candle-examples/examples/t5/main.rs7
1 files changed, 6 insertions, 1 deletions
diff --git a/candle-examples/examples/t5/main.rs b/candle-examples/examples/t5/main.rs
index fe59d578..f1c5a94b 100644
--- a/candle-examples/examples/t5/main.rs
+++ b/candle-examples/examples/t5/main.rs
@@ -172,7 +172,12 @@ fn main() -> Result<()> {
println!("Took {:?}", start.elapsed());
} else {
let mut model = builder.build_conditional_generation()?;
- let mut output_token_ids = [builder.config.pad_token_id as u32].to_vec();
+ let mut output_token_ids = [builder
+ .config
+ .decoder_start_token_id
+ .unwrap_or(builder.config.pad_token_id)
+ as u32]
+ .to_vec();
if let Some(decoder_prompt) = &args.decoder_prompt {
print!("{decoder_prompt}");
output_token_ids.extend(