author     Laurent Mazare <laurent.mazare@gmail.com>  2023-09-30 16:04:11 +0200
committer  GitHub <noreply@github.com>                2023-09-30 15:04:11 +0100
commit     06207332bc58e20680dd1925b7d90bac51f4f21c (patch)
tree       21602b4923638fcea5463cf02567865a8a7b3806
parent     4021272875605aa8a7196a5d08ccf901c3ea6c4b (diff)
Streaming mode for reporting the generated tokens (#1007)
* Token streaming.
* Use the token output stream.
* Flush the output.
* Ensure that the last characters get reported.
-rw-r--r--  candle-examples/Cargo.toml                 |  2
-rw-r--r--  candle-examples/examples/mistral/main.rs   | 30
-rw-r--r--  candle-examples/src/lib.rs                 |  1
-rw-r--r--  candle-examples/src/token_output_stream.rs | 74
4 files changed, 96 insertions(+), 11 deletions(-)
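The new `TokenOutputStream` helper introduced by this commit (full diff below) is what makes the streaming output possible. Here is a minimal sketch, not part of the commit itself, of how it is meant to be driven; the `tokenizer` argument and the `sampled_ids` iterator stand in for the real tokenizer loading and the model's sampling loop:

```rust
use candle_examples::token_output_stream::TokenOutputStream;

/// Hypothetical helper illustrating the streaming pattern this change adds.
fn stream_output(
    tokenizer: tokenizers::Tokenizer,
    sampled_ids: impl Iterator<Item = u32>,
) -> anyhow::Result<()> {
    use std::io::Write;
    let mut stream = TokenOutputStream::new(tokenizer);
    for id in sampled_ids {
        // `next_token` buffers ids and only yields text once it decodes
        // cleanly, so whitespace is no longer swallowed token by token.
        if let Some(text) = stream.next_token(id)? {
            print!("{text}");
            std::io::stdout().flush()?;
        }
    }
    // Report whatever is still buffered so the last characters show up.
    print!("{}", stream.decode_rest()?);
    std::io::stdout().flush()?;
    Ok(())
}
```

The mistral example below follows the same pattern inside `TextGeneration::run`, calling `clear()` first so the stream can be reused across prompts.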
diff --git a/candle-examples/Cargo.toml b/candle-examples/Cargo.toml
index 171415cf..ad2bbf39 100644
--- a/candle-examples/Cargo.toml
+++ b/candle-examples/Cargo.toml
@@ -25,6 +25,7 @@ rayon = { workspace = true }
 safetensors = { workspace = true }
 serde = { workspace = true }
 serde_json = { workspace = true }
+tokenizers = { workspace = true, features = ["onig"] }
 
 [dev-dependencies]
 anyhow = { workspace = true }
@@ -35,7 +36,6 @@ imageproc = { workspace = true }
 memmap2 = { workspace = true }
 rand = { workspace = true }
 rusttype = { workspace = true }
-tokenizers = { workspace = true, features = ["onig"] }
 tracing = { workspace = true }
 tracing-chrome = { workspace = true }
 tracing-subscriber = { workspace = true }
diff --git a/candle-examples/examples/mistral/main.rs b/candle-examples/examples/mistral/main.rs
index e0cecf15..6fe08963 100644
--- a/candle-examples/examples/mistral/main.rs
+++ b/candle-examples/examples/mistral/main.rs
@@ -10,6 +10,7 @@ use clap::Parser;
 
 use candle_transformers::models::mistral::{Config, Model};
 use candle::{DType, Device, Tensor};
+use candle_examples::token_output_stream::TokenOutputStream;
 use candle_nn::VarBuilder;
 use candle_transformers::generation::LogitsProcessor;
 use hf_hub::{api::sync::Api, Repo, RepoType};
@@ -18,7 +19,7 @@ use tokenizers::Tokenizer;
 struct TextGeneration {
     model: Model,
     device: Device,
-    tokenizer: Tokenizer,
+    tokenizer: TokenOutputStream,
     logits_processor: LogitsProcessor,
     repeat_penalty: f32,
     repeat_last_n: usize,
@@ -39,7 +40,7 @@ impl TextGeneration {
         let logits_processor = LogitsProcessor::new(seed, temp, top_p);
         Self {
             model,
-            tokenizer,
+            tokenizer: TokenOutputStream::new(tokenizer),
             logits_processor,
             repeat_penalty,
             repeat_last_n,
@@ -49,18 +50,24 @@ impl TextGeneration {
 
     fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
         use std::io::Write;
-        println!("starting the inference loop");
-        std::io::stdout().flush()?;
+        self.tokenizer.clear();
         let mut tokens = self
             .tokenizer
+            .tokenizer()
             .encode(prompt, true)
             .map_err(E::msg)?
             .get_ids()
             .to_vec();
+        for &t in tokens.iter() {
+            if let Some(t) = self.tokenizer.next_token(t)? {
+                print!("{t}")
+            }
+        }
+        std::io::stdout().flush()?;
 
         let mut generated_tokens = 0usize;
-        let eos_token = match self.tokenizer.get_vocab(true).get("</s>") {
-            Some(token) => *token,
+        let eos_token = match self.tokenizer.get_token("</s>") {
+            Some(token) => token,
             None => anyhow::bail!("cannot find the </s> token"),
         };
         let start_gen = std::time::Instant::now();
@@ -88,12 +95,15 @@ impl TextGeneration {
             if next_token == eos_token {
                 break;
             }
-            // TODO: print the generated tokens in a streaming way. Using `self.tokenizer.decode`
-            // on each token seems to swallow the whitespaces.
+            if let Some(t) = self.tokenizer.next_token(next_token)? {
+                print!("{t}");
+                std::io::stdout().flush()?;
+            }
         }
         let dt = start_gen.elapsed();
-        let generated_text = self.tokenizer.decode(&tokens, true).map_err(E::msg)?;
-        println!("Generated text:\n{generated_text}");
+        let rest = self.tokenizer.decode_rest().map_err(E::msg)?;
+        print!("{rest}");
+        std::io::stdout().flush()?;
         println!(
             "\n{generated_tokens} tokens generated ({:.2} token/s)",
             generated_tokens as f64 / dt.as_secs_f64(),
diff --git a/candle-examples/src/lib.rs b/candle-examples/src/lib.rs
index 5e0b44fb..4ef97f88 100644
--- a/candle-examples/src/lib.rs
+++ b/candle-examples/src/lib.rs
@@ -1,5 +1,6 @@
 pub mod coco_classes;
 pub mod imagenet;
+pub mod token_output_stream;
 
 use candle::{Device, Result, Tensor};
 
diff --git a/candle-examples/src/token_output_stream.rs b/candle-examples/src/token_output_stream.rs
new file mode 100644
index 00000000..3d975d63
--- /dev/null
+++ b/candle-examples/src/token_output_stream.rs
@@ -0,0 +1,74 @@
+use candle::Result;
+
+/// This is a wrapper around a tokenizer to ensure that tokens can be returned to the user in a
+/// streaming way rather than having to wait for the full decoding.
+pub struct TokenOutputStream {
+    tokenizer: tokenizers::Tokenizer,
+    tokens: Vec<u32>,
+    prev_index: usize,
+    current_index: usize,
+}
+
+impl TokenOutputStream {
+    pub fn new(tokenizer: tokenizers::Tokenizer) -> Self {
+        Self {
+            tokenizer,
+            tokens: Vec::new(),
+            prev_index: 0,
+            current_index: 0,
+        }
+    }
+
+    pub fn into_inner(self) -> tokenizers::Tokenizer {
+        self.tokenizer
+    }
+
+    fn decode(&self, tokens: &[u32]) -> Result<String> {
+        match self.tokenizer.decode(tokens, true) {
+            Ok(str) => Ok(str),
+            Err(err) => candle::bail!("cannot decode: {err}"),
+        }
+    }
+
+    // https://github.com/huggingface/text-generation-inference/blob/5ba53d44a18983a4de32d122f4cb46f4a17d9ef6/server/text_generation_server/models/model.py#L68
+    pub fn next_token(&mut self, token: u32) -> Result<Option<String>> {
+        let prev_text = if self.tokens.is_empty() {
+            String::new()
+        } else {
+            let tokens = &self.tokens[self.prev_index..self.current_index];
+            self.decode(tokens)?
+        };
+        self.tokens.push(token);
+        let text = self.decode(&self.tokens[self.prev_index..])?;
+        if text.len() > prev_text.len() && text.chars().last().unwrap().is_ascii() {
+            let text = text.split_at(prev_text.len());
+            self.prev_index = self.current_index;
+            self.current_index = self.tokens.len();
+            Ok(Some(text.1.to_string()))
+        } else {
+            Ok(None)
+        }
+    }
+
+    pub fn decode_rest(&self) -> Result<String> {
+        self.decode(&self.tokens[self.prev_index..])
+    }
+
+    pub fn decode_all(&self) -> Result<String> {
+        self.decode(&self.tokens)
+    }
+
+    pub fn get_token(&self, token_s: &str) -> Option<u32> {
+        self.tokenizer.get_vocab(true).get(token_s).copied()
+    }
+
+    pub fn tokenizer(&self) -> &tokenizers::Tokenizer {
+        &self.tokenizer
+    }
+
+    pub fn clear(&mut self) {
+        self.tokens.clear();
+        self.prev_index = 0;
+        self.current_index = 0;
+    }
+}
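Design note: the TODO removed from `main.rs` pointed out that decoding each token on its own "seems to swallow the whitespaces". `TokenOutputStream::next_token` sidesteps this by keeping every generated id and re-decoding the whole span since `prev_index`: the chunk it returns is only the text that extends past the previously decoded prefix, and nothing is emitted until the decoded string ends in an ASCII character, so partially decoded sequences stay buffered until a later token or `decode_rest` completes them. The approach follows the incremental decoding used in text-generation-inference, linked from the source comment.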