Diffstat (limited to 'candle-examples/examples/mistral/main.rs')
 candle-examples/examples/mistral/main.rs | 30 ++++++++++++++++++----------
1 file changed, 20 insertions(+), 10 deletions(-)
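In short: instead of decoding the whole token sequence once generation has finished, the example now wraps its Tokenizer in a TokenOutputStream and prints text incrementally. The prompt tokens are echoed as they are fed in, each generated token is printed (and stdout flushed) as soon as it can be decoded unambiguously, and a final decode_rest flushes whatever the stream was still holding back.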
diff --git a/candle-examples/examples/mistral/main.rs b/candle-examples/examples/mistral/main.rs
index e0cecf15..6fe08963 100644
--- a/candle-examples/examples/mistral/main.rs
+++ b/candle-examples/examples/mistral/main.rs
@@ -10,6 +10,7 @@ use clap::Parser;
 use candle_transformers::models::mistral::{Config, Model};
 
 use candle::{DType, Device, Tensor};
+use candle_examples::token_output_stream::TokenOutputStream;
 use candle_nn::VarBuilder;
 use candle_transformers::generation::LogitsProcessor;
 use hf_hub::{api::sync::Api, Repo, RepoType};
@@ -18,7 +19,7 @@ use tokenizers::Tokenizer;
 struct TextGeneration {
     model: Model,
     device: Device,
-    tokenizer: Tokenizer,
+    tokenizer: TokenOutputStream,
     logits_processor: LogitsProcessor,
     repeat_penalty: f32,
     repeat_last_n: usize,
@@ -39,7 +40,7 @@ impl TextGeneration {
         let logits_processor = LogitsProcessor::new(seed, temp, top_p);
         Self {
             model,
-            tokenizer,
+            tokenizer: TokenOutputStream::new(tokenizer),
             logits_processor,
             repeat_penalty,
             repeat_last_n,
@@ -49,18 +50,24 @@ impl TextGeneration {
 
     fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
         use std::io::Write;
-        println!("starting the inference loop");
-        std::io::stdout().flush()?;
+        self.tokenizer.clear();
         let mut tokens = self
             .tokenizer
+            .tokenizer()
             .encode(prompt, true)
             .map_err(E::msg)?
             .get_ids()
             .to_vec();
+        for &t in tokens.iter() {
+            if let Some(t) = self.tokenizer.next_token(t)? {
+                print!("{t}")
+            }
+        }
+        std::io::stdout().flush()?;
 
         let mut generated_tokens = 0usize;
-        let eos_token = match self.tokenizer.get_vocab(true).get("</s>") {
-            Some(token) => *token,
+        let eos_token = match self.tokenizer.get_token("</s>") {
+            Some(token) => token,
             None => anyhow::bail!("cannot find the </s> token"),
         };
         let start_gen = std::time::Instant::now();
@@ -88,12 +95,15 @@ impl TextGeneration {
             if next_token == eos_token {
                 break;
             }
-            // TODO: print the generated tokens in a streaming way. Using `self.tokenizer.decode`
-            // on each token seems to swallow the whitespaces.
+            if let Some(t) = self.tokenizer.next_token(next_token)? {
+                print!("{t}");
+                std::io::stdout().flush()?;
+            }
         }
         let dt = start_gen.elapsed();
-        let generated_text = self.tokenizer.decode(&tokens, true).map_err(E::msg)?;
-        println!("Generated text:\n{generated_text}");
+        let rest = self.tokenizer.decode_rest().map_err(E::msg)?;
+        print!("{rest}");
+        std::io::stdout().flush()?;
         println!(
             "\n{generated_tokens} tokens generated ({:.2} token/s)",
             generated_tokens as f64 / dt.as_secs_f64(),
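
The TODO removed in the last hunk names the underlying problem: calling tokenizer.decode on each token in isolation drops the leading whitespace that SentencePiece-style tokenizers fold into the token itself ("▁world" only decodes to " world" together with its neighbours). TokenOutputStream works around this by re-decoding a growing window of tokens and emitting only the newly appended suffix. Below is a minimal sketch of that idea, assuming only the tokenizers and anyhow crates; the name StreamDecoder and the exact hold-back condition are illustrative, not the actual candle_examples implementation.

use anyhow::{Error as E, Result};

/// Wraps a tokenizer so generated tokens can be printed as they arrive.
struct StreamDecoder {
    tokenizer: tokenizers::Tokenizer,
    tokens: Vec<u32>,     // every token seen since the stream was created
    prev_index: usize,    // start of the decode window
    current_index: usize, // end of the already-printed part of the window
}

impl StreamDecoder {
    fn new(tokenizer: tokenizers::Tokenizer) -> Self {
        Self { tokenizer, tokens: Vec::new(), prev_index: 0, current_index: 0 }
    }

    fn decode(&self, tokens: &[u32]) -> Result<String> {
        self.tokenizer.decode(tokens, true).map_err(E::msg)
    }

    /// Feed one token id; return any text that is now safe to print.
    fn next_token(&mut self, token: u32) -> Result<Option<String>> {
        // Decode the already-printed part of the window first so the
        // new text can be split off from it below.
        let prev_text = if self.tokens.is_empty() {
            String::new()
        } else {
            self.decode(&self.tokens[self.prev_index..self.current_index])?
        };
        self.tokens.push(token);
        let text = self.decode(&self.tokens[self.prev_index..])?;
        // Hold the text back while it ends in a replacement character:
        // byte-fallback tokens decode to U+FFFD until all their bytes
        // have arrived.
        if text.len() > prev_text.len() && !text.ends_with('\u{FFFD}') {
            let new = text[prev_text.len()..].to_string();
            self.prev_index = self.current_index;
            self.current_index = self.tokens.len();
            Ok(Some(new))
        } else {
            Ok(None)
        }
    }

    /// Flush whatever is still held back, once generation is done.
    fn decode_rest(&self) -> Result<String> {
        let prev_text = if self.tokens.is_empty() {
            String::new()
        } else {
            self.decode(&self.tokens[self.prev_index..self.current_index])?
        };
        let text = self.decode(&self.tokens[self.prev_index..])?;
        Ok(text[prev_text.len()..].to_string())
    }
}

The decode_rest call after the loop matters because next_token can still be holding text back when the eos token arrives, which is why the diff prints the remainder and flushes stdout before reporting the token rate.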