diff options
Diffstat (limited to 'candle-examples/examples')
-rw-r--r-- | candle-examples/examples/quantized-t5/main.rs | 6 | ||||
-rw-r--r-- | candle-examples/examples/t5/main.rs | 7 |
2 files changed, 11 insertions, 2 deletions
diff --git a/candle-examples/examples/quantized-t5/main.rs b/candle-examples/examples/quantized-t5/main.rs index 5a1cdf0c..0ea2e0bd 100644 --- a/candle-examples/examples/quantized-t5/main.rs +++ b/candle-examples/examples/quantized-t5/main.rs @@ -173,7 +173,11 @@ fn main() -> Result<()> { .to_vec(); let input_token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?; let mut model = builder.build_model()?; - let mut output_token_ids = [builder.config.pad_token_id as u32].to_vec(); + let mut output_token_ids = [builder + .config + .decoder_start_token_id + .unwrap_or(builder.config.pad_token_id) as u32] + .to_vec(); let temperature = if args.temperature <= 0. { None } else { diff --git a/candle-examples/examples/t5/main.rs b/candle-examples/examples/t5/main.rs index fe59d578..f1c5a94b 100644 --- a/candle-examples/examples/t5/main.rs +++ b/candle-examples/examples/t5/main.rs @@ -172,7 +172,12 @@ fn main() -> Result<()> { println!("Took {:?}", start.elapsed()); } else { let mut model = builder.build_conditional_generation()?; - let mut output_token_ids = [builder.config.pad_token_id as u32].to_vec(); + let mut output_token_ids = [builder + .config + .decoder_start_token_id + .unwrap_or(builder.config.pad_token_id) + as u32] + .to_vec(); if let Some(decoder_prompt) = &args.decoder_prompt { print!("{decoder_prompt}"); output_token_ids.extend( |