Diffstat (limited to 'candle-examples/examples/mistral/main.rs')
-rw-r--r--  candle-examples/examples/mistral/main.rs | 30 ++++++++++++++++++++----------
1 file changed, 20 insertions(+), 10 deletions(-)
diff --git a/candle-examples/examples/mistral/main.rs b/candle-examples/examples/mistral/main.rs
index e0cecf15..6fe08963 100644
--- a/candle-examples/examples/mistral/main.rs
+++ b/candle-examples/examples/mistral/main.rs
@@ -10,6 +10,7 @@ use clap::Parser;
 use candle_transformers::models::mistral::{Config, Model};
 
 use candle::{DType, Device, Tensor};
+use candle_examples::token_output_stream::TokenOutputStream;
 use candle_nn::VarBuilder;
 use candle_transformers::generation::LogitsProcessor;
 use hf_hub::{api::sync::Api, Repo, RepoType};
@@ -18,7 +19,7 @@ use tokenizers::Tokenizer;
 struct TextGeneration {
     model: Model,
     device: Device,
-    tokenizer: Tokenizer,
+    tokenizer: TokenOutputStream,
     logits_processor: LogitsProcessor,
     repeat_penalty: f32,
     repeat_last_n: usize,
@@ -39,7 +40,7 @@ impl TextGeneration {
         let logits_processor = LogitsProcessor::new(seed, temp, top_p);
         Self {
             model,
-            tokenizer,
+            tokenizer: TokenOutputStream::new(tokenizer),
             logits_processor,
             repeat_penalty,
             repeat_last_n,
@@ -49,18 +50,24 @@ impl TextGeneration {
 
     fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
         use std::io::Write;
-        println!("starting the inference loop");
-        std::io::stdout().flush()?;
+        self.tokenizer.clear();
         let mut tokens = self
             .tokenizer
+            .tokenizer()
             .encode(prompt, true)
             .map_err(E::msg)?
             .get_ids()
             .to_vec();
+        for &t in tokens.iter() {
+            if let Some(t) = self.tokenizer.next_token(t)? {
+                print!("{t}")
+            }
+        }
+        std::io::stdout().flush()?;
 
         let mut generated_tokens = 0usize;
-        let eos_token = match self.tokenizer.get_vocab(true).get("</s>") {
-            Some(token) => *token,
+        let eos_token = match self.tokenizer.get_token("</s>") {
+            Some(token) => token,
             None => anyhow::bail!("cannot find the </s> token"),
         };
         let start_gen = std::time::Instant::now();
@@ -88,12 +95,15 @@ impl TextGeneration {
             if next_token == eos_token {
                 break;
             }
-            // TODO: print the generated tokens in a streaming way. Using `self.tokenizer.decode`
-            // on each token seems to swallow the whitespaces.
+            if let Some(t) = self.tokenizer.next_token(next_token)? {
+                print!("{t}");
+                std::io::stdout().flush()?;
+            }
         }
         let dt = start_gen.elapsed();
-        let generated_text = self.tokenizer.decode(&tokens, true).map_err(E::msg)?;
-        println!("Generated text:\n{generated_text}");
+        let rest = self.tokenizer.decode_rest().map_err(E::msg)?;
+        print!("{rest}");
+        std::io::stdout().flush()?;
         println!(
             "\n{generated_tokens} tokens generated ({:.2} token/s)",
             generated_tokens as f64 / dt.as_secs_f64(),
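
The removed TODO captures the motivation for this change: calling `self.tokenizer.decode` on each generated token in isolation swallows whitespace, because SentencePiece-style tokenizers fold the leading space into the token itself. `TokenOutputStream` wraps the `Tokenizer`, buffers token ids, and on each `next_token` call decodes the buffer and returns only the newly appended text, which is why both the prompt tokens and the sampled tokens are now fed through the same streaming path. Below is a minimal sketch of that incremental-decoding idea; `StreamDecoder` and its fields are illustrative stand-ins, not candle's actual implementation.

// Sketch of incremental decoding, assuming the `tokenizers` and `anyhow`
// crates; `StreamDecoder` is a hypothetical stand-in for candle-examples'
// TokenOutputStream.
use anyhow::{Error as E, Result};
use tokenizers::Tokenizer;

struct StreamDecoder {
    tokenizer: Tokenizer,
    tokens: Vec<u32>,
    printed: usize, // byte length of the text already emitted
}

impl StreamDecoder {
    fn new(tokenizer: Tokenizer) -> Self {
        Self { tokenizer, tokens: Vec::new(), printed: 0 }
    }

    // Decode the whole buffer and return only the suffix added since the
    // last call; decoding token-by-token would drop the leading space that
    // SentencePiece encodes as part of each token.
    fn next_token(&mut self, token: u32) -> Result<Option<String>> {
        self.tokens.push(token);
        let text = self.tokenizer.decode(&self.tokens, true).map_err(E::msg)?;
        // Assumes the previous decode stays a prefix of the new one, so
        // `printed` always lands on a char boundary.
        match text.get(self.printed..) {
            Some(new) if !new.is_empty() => {
                self.printed = text.len();
                Ok(Some(new.to_string()))
            }
            _ => Ok(None),
        }
    }
}

The real `TokenOutputStream` is somewhat more careful: it re-decodes only the tokens since the last emission and holds output back until the decoded text ends in an alphanumeric character, so partially decoded multi-byte sequences never reach stdout. That held-back tail is what the `decode_rest()` call flushes once the generation loop ends.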