diff options
Diffstat (limited to 'candle-examples/examples')
-rw-r--r-- | candle-examples/examples/blip/main.rs | 2 | ||||
-rw-r--r-- | candle-examples/examples/marian-mt/main.rs | 23 |
2 files changed, 15 insertions, 10 deletions
diff --git a/candle-examples/examples/blip/main.rs b/candle-examples/examples/blip/main.rs index 45300feb..a1051a8e 100644 --- a/candle-examples/examples/blip/main.rs +++ b/candle-examples/examples/blip/main.rs @@ -149,6 +149,6 @@ pub fn main() -> anyhow::Result<()> { if let Some(rest) = tokenizer.decode_rest().map_err(E::msg)? { print!("{rest}"); } - + println!(); Ok(()) } diff --git a/candle-examples/examples/marian-mt/main.rs b/candle-examples/examples/marian-mt/main.rs index c503667c..89b3a9a3 100644 --- a/candle-examples/examples/marian-mt/main.rs +++ b/candle-examples/examples/marian-mt/main.rs @@ -8,6 +8,7 @@ use anyhow::Error as E; use clap::{Parser, ValueEnum}; use candle::{DType, Tensor}; +use candle_examples::token_output_stream::TokenOutputStream; use candle_nn::VarBuilder; use candle_transformers::models::marian; @@ -87,6 +88,7 @@ pub fn main() -> anyhow::Result<()> { }; Tokenizer::from_file(&tokenizer).map_err(E::msg)? }; + let mut tokenizer_dec = TokenOutputStream::new(tokenizer_dec); let device = candle_examples::device(args.cpu)?; let vb = { @@ -107,7 +109,7 @@ pub fn main() -> anyhow::Result<()> { }; unsafe { VarBuilder::from_mmaped_safetensors(&[&model], DType::F32, &device)? } }; - let model = marian::MTModel::new(&config, vb)?; + let mut model = marian::MTModel::new(&config, vb)?; let mut logits_processor = candle_transformers::generation::LogitsProcessor::new(1337, None, None); @@ -125,23 +127,26 @@ pub fn main() -> anyhow::Result<()> { let mut token_ids = vec![config.decoder_start_token_id]; for index in 0..1000 { - // TODO: Add a kv cache. - let context_size = if index >= 1000 { 1 } else { token_ids.len() }; + let context_size = if index >= 1 { 1 } else { token_ids.len() }; let start_pos = token_ids.len().saturating_sub(context_size); let input_ids = Tensor::new(&token_ids[start_pos..], &device)?.unsqueeze(0)?; - let logits = model.decode(&input_ids, &encoder_xs)?; + let logits = model.decode(&input_ids, &encoder_xs, start_pos)?; let logits = logits.squeeze(0)?; let logits = logits.get(logits.dim(0)? - 1)?; let token = logits_processor.sample(&logits)?; token_ids.push(token); - println!("{token}"); + if let Some(t) = tokenizer_dec.next_token(token)? { + use std::io::Write; + print!("{t}"); + std::io::stdout().flush()?; + } if token == config.eos_token_id || token == config.forced_eos_token_id { break; } } - println!( - "{}", - tokenizer_dec.decode(&token_ids, true).map_err(E::msg)? - ); + if let Some(rest) = tokenizer_dec.decode_rest().map_err(E::msg)? { + print!("{rest}"); + } + println!(); Ok(()) } |