Diffstat (limited to 'candle-examples/examples/marian-mt/main.rs')
-rw-r--r--  candle-examples/examples/marian-mt/main.rs  23
1 file changed, 14 insertions(+), 9 deletions(-)
diff --git a/candle-examples/examples/marian-mt/main.rs b/candle-examples/examples/marian-mt/main.rs
index c503667c..89b3a9a3 100644
--- a/candle-examples/examples/marian-mt/main.rs
+++ b/candle-examples/examples/marian-mt/main.rs
@@ -8,6 +8,7 @@ use anyhow::Error as E;
use clap::{Parser, ValueEnum};
use candle::{DType, Tensor};
+use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
use candle_transformers::models::marian;
@@ -87,6 +88,7 @@ pub fn main() -> anyhow::Result<()> {
};
Tokenizer::from_file(&tokenizer).map_err(E::msg)?
};
+ let mut tokenizer_dec = TokenOutputStream::new(tokenizer_dec);
let device = candle_examples::device(args.cpu)?;
let vb = {
@@ -107,7 +109,7 @@ pub fn main() -> anyhow::Result<()> {
};
unsafe { VarBuilder::from_mmaped_safetensors(&[&model], DType::F32, &device)? }
};
- let model = marian::MTModel::new(&config, vb)?;
+ let mut model = marian::MTModel::new(&config, vb)?;
let mut logits_processor =
candle_transformers::generation::LogitsProcessor::new(1337, None, None);
@@ -125,23 +127,26 @@ pub fn main() -> anyhow::Result<()> {
let mut token_ids = vec![config.decoder_start_token_id];
for index in 0..1000 {
- // TODO: Add a kv cache.
- let context_size = if index >= 1000 { 1 } else { token_ids.len() };
+ let context_size = if index >= 1 { 1 } else { token_ids.len() };
let start_pos = token_ids.len().saturating_sub(context_size);
let input_ids = Tensor::new(&token_ids[start_pos..], &device)?.unsqueeze(0)?;
- let logits = model.decode(&input_ids, &encoder_xs)?;
+ let logits = model.decode(&input_ids, &encoder_xs, start_pos)?;
let logits = logits.squeeze(0)?;
let logits = logits.get(logits.dim(0)? - 1)?;
let token = logits_processor.sample(&logits)?;
token_ids.push(token);
- println!("{token}");
+ if let Some(t) = tokenizer_dec.next_token(token)? {
+ use std::io::Write;
+ print!("{t}");
+ std::io::stdout().flush()?;
+ }
if token == config.eos_token_id || token == config.forced_eos_token_id {
break;
}
}
- println!(
- "{}",
- tokenizer_dec.decode(&token_ids, true).map_err(E::msg)?
- );
+ if let Some(rest) = tokenizer_dec.decode_rest().map_err(E::msg)? {
+ print!("{rest}");
+ }
+ println!();
Ok(())
}
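
Note: the change above does two things in the decoding loop. First, after the initial step it passes only the newest token to model.decode together with start_pos, relying on the model now caching past decoder state (hence the removed "TODO: Add a kv cache") instead of re-running the whole prefix each step. Second, it streams readable text through TokenOutputStream as tokens are sampled, rather than printing raw token ids and decoding everything at the end. The sketch below illustrates that loop shape in plain Rust; MockDecoder, MockStreamer, and the token ids are hypothetical stand-ins for illustration, not the candle API.

use std::io::Write;

// Hypothetical stand-in for marian::MTModel: the real decoder keeps a kv cache,
// so each step only needs the newest token plus its absolute position.
struct MockDecoder {
    steps: usize,
}

impl MockDecoder {
    // Returns a fake next-token id; the real model returns logits to sample from.
    fn decode(&mut self, input_ids: &[u32], start_pos: usize) -> u32 {
        // With a kv cache, `input_ids` holds a single token after step 0.
        assert!(start_pos == 0 || input_ids.len() == 1);
        self.steps += 1;
        if self.steps > 5 {
            2 // pretend this is the eos id
        } else {
            100 + self.steps as u32
        }
    }
}

// Hypothetical stand-in for TokenOutputStream: the real one buffers ids until
// they decode to a printable piece of text and returns None otherwise.
struct MockStreamer;

impl MockStreamer {
    fn next_token(&mut self, token: u32) -> Option<String> {
        Some(format!("<{token}>"))
    }
    fn decode_rest(&self) -> Option<String> {
        None
    }
}

fn main() -> std::io::Result<()> {
    let decoder_start_token_id = 0u32; // assumed id
    let eos_token_id = 2u32; // assumed id
    let mut model = MockDecoder { steps: 0 };
    let mut streamer = MockStreamer;

    let mut token_ids = vec![decoder_start_token_id];
    for index in 0..1000 {
        // After the first iteration only the last generated token is fed in;
        // `start_pos` tells the decoder where that token sits in the sequence.
        let context_size = if index >= 1 { 1 } else { token_ids.len() };
        let start_pos = token_ids.len().saturating_sub(context_size);
        let token = model.decode(&token_ids[start_pos..], start_pos);
        token_ids.push(token);

        // Stream text as soon as it becomes decodable instead of printing ids.
        if let Some(t) = streamer.next_token(token) {
            print!("{t}");
            std::io::stdout().flush()?;
        }
        if token == eos_token_id {
            break;
        }
    }
    if let Some(rest) = streamer.decode_rest() {
        print!("{rest}");
    }
    println!();
    Ok(())
}

In the actual example the sampled token comes from LogitsProcessor::sample over the last-step logits, and next_token/decode_rest are provided by candle_examples::token_output_stream::TokenOutputStream, as shown in the diff.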