author     Laurent Mazare <laurent.mazare@gmail.com>  2023-09-30 16:04:11 +0200
committer  GitHub <noreply@github.com>                2023-09-30 15:04:11 +0100
commit     06207332bc58e20680dd1925b7d90bac51f4f21c (patch)
tree       21602b4923638fcea5463cf02567865a8a7b3806 /candle-examples/src/token_output_stream.rs
parent     4021272875605aa8a7196a5d08ccf901c3ea6c4b (diff)
Streaming mode for reporting the generated tokens (#1007)
* Token streaming.
* Use the token output stream.
* Flush the output.
* Ensure that the last characters get reported.
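
For context, a minimal sketch of how an example binary might drive the new stream. The assumptions here are not part of this commit: that candle-examples exports the module as `candle_examples::token_output_stream`, that a `tokenizer.json` file is available locally, and that the hard-coded token ids stand in for a real model sampling loop.

use candle_examples::token_output_stream::TokenOutputStream;
use std::io::Write;

fn main() -> anyhow::Result<()> {
    let tokenizer =
        tokenizers::Tokenizer::from_file("tokenizer.json").map_err(anyhow::Error::msg)?;
    let mut token_stream = TokenOutputStream::new(tokenizer);
    // Stand-in for the model's sampling loop; these token ids are arbitrary.
    for next_token in [9038u32, 2501, 263, 931] {
        // Report each fragment as soon as it decodes cleanly, flushing so the
        // user sees tokens appear immediately rather than at line boundaries.
        if let Some(text) = token_stream.next_token(next_token)? {
            print!("{text}");
            std::io::stdout().flush()?;
        }
    }
    // Ensure that the last buffered characters get reported.
    print!("{}", token_stream.decode_rest()?);
    std::io::stdout().flush()?;
    Ok(())
}

Flushing stdout after every fragment is what makes the output appear token by token instead of only when a newline is printed.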
Diffstat (limited to 'candle-examples/src/token_output_stream.rs')
-rw-r--r--  candle-examples/src/token_output_stream.rs  74
1 file changed, 74 insertions(+), 0 deletions(-)
diff --git a/candle-examples/src/token_output_stream.rs b/candle-examples/src/token_output_stream.rs
new file mode 100644
index 00000000..3d975d63
--- /dev/null
+++ b/candle-examples/src/token_output_stream.rs
@@ -0,0 +1,74 @@
+use candle::Result;
+
+/// This is a wrapper around a tokenizer to ensure that tokens can be returned to the user in a
+/// streaming way rather than having to wait for the full decoding.
+pub struct TokenOutputStream {
+    tokenizer: tokenizers::Tokenizer,
+    tokens: Vec<u32>,
+    prev_index: usize,
+    current_index: usize,
+}
+
+impl TokenOutputStream {
+    pub fn new(tokenizer: tokenizers::Tokenizer) -> Self {
+        Self {
+            tokenizer,
+            tokens: Vec::new(),
+            prev_index: 0,
+            current_index: 0,
+        }
+    }
+
+    pub fn into_inner(self) -> tokenizers::Tokenizer {
+        self.tokenizer
+    }
+
+    fn decode(&self, tokens: &[u32]) -> Result<String> {
+        match self.tokenizer.decode(tokens, true) {
+            Ok(str) => Ok(str),
+            Err(err) => candle::bail!("cannot decode: {err}"),
+        }
+    }
+
+    // https://github.com/huggingface/text-generation-inference/blob/5ba53d44a18983a4de32d122f4cb46f4a17d9ef6/server/text_generation_server/models/model.py#L68
+    pub fn next_token(&mut self, token: u32) -> Result<Option<String>> {
+        let prev_text = if self.tokens.is_empty() {
+            String::new()
+        } else {
+            let tokens = &self.tokens[self.prev_index..self.current_index];
+            self.decode(tokens)?
+        };
+        self.tokens.push(token);
+        let text = self.decode(&self.tokens[self.prev_index..])?;
+        if text.len() > prev_text.len() && text.chars().last().unwrap().is_ascii() {
+            let text = text.split_at(prev_text.len());
+            self.prev_index = self.current_index;
+            self.current_index = self.tokens.len();
+            Ok(Some(text.1.to_string()))
+        } else {
+            Ok(None)
+        }
+    }
+
+    pub fn decode_rest(&self) -> Result<String> {
+        self.decode(&self.tokens[self.prev_index..])
+    }
+
+    pub fn decode_all(&self) -> Result<String> {
+        self.decode(&self.tokens)
+    }
+
+    pub fn get_token(&self, token_s: &str) -> Option<u32> {
+        self.tokenizer.get_vocab(true).get(token_s).copied()
+    }
+
+    pub fn tokenizer(&self) -> &tokenizers::Tokenizer {
+        &self.tokenizer
+    }
+
+    pub fn clear(&mut self) {
+        self.tokens.clear();
+        self.prev_index = 0;
+        self.current_index = 0;
+    }
+}
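
A note on why `next_token` holds text back: with byte-level or byte-fallback tokenizers, a token boundary can fall inside a multi-byte UTF-8 character, and decoding such a prefix typically ends in the replacement character U+FFFD, which is not ASCII. The `is_ascii()` test on the final character therefore only releases the new suffix once it is sure to be complete, and `decode_rest` flushes whatever is still buffered when generation ends, which is why the sketch above prints its result after the loop. A tiny self-contained illustration of the invariant the check relies on, independent of any tokenizer:

fn main() {
    // 0xE2 0x82 is a truncated euro sign (0xE2 0x82 0xAC); lossy decoding
    // turns the incomplete sequence into U+FFFD, which fails is_ascii().
    let partial = String::from_utf8_lossy(&[0xE2, 0x82]);
    assert_eq!(partial, "\u{FFFD}");
    assert!(!partial.chars().last().unwrap().is_ascii());
}

This mirrors the incremental-decoding approach in the text-generation-inference code linked from the comment above `next_token`.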