Diffstat (limited to 'candle-examples/examples/llama2-c/main.rs')
-rw-r--r--  candle-examples/examples/llama2-c/main.rs | 13 ++++++-------
1 file changed, 6 insertions(+), 7 deletions(-)
diff --git a/candle-examples/examples/llama2-c/main.rs b/candle-examples/examples/llama2-c/main.rs
index 9d42dcc8..27ebc80f 100644
--- a/candle-examples/examples/llama2-c/main.rs
+++ b/candle-examples/examples/llama2-c/main.rs
@@ -328,6 +328,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
         .map_err(E::msg)?
         .get_ids()
         .to_vec();
+    let mut tokenizer = candle_examples::token_output_stream::TokenOutputStream::new(tokenizer);

     let start_gen = std::time::Instant::now();
     for index in 0.. {
@@ -353,16 +354,14 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
         let next_token = logits_processor.sample(&logits)?;
         tokens.push(next_token);

-        // Extracting the last token as a string is complicated, here we just apply some simple
-        // heuristics as it seems to work well enough for this example. See the following for more
-        // details:
-        // https://github.com/huggingface/tokenizers/issues/1141#issuecomment-1562644141
-        if let Some(text) = tokenizer.id_to_token(next_token) {
-            let text = text.replace('▁', " ").replace("<0x0A>", "\n");
-            print!("{text}");
+        if let Some(t) = tokenizer.next_token(next_token)? {
+            print!("{t}");
             std::io::stdout().flush()?;
         }
     }
+    if let Some(rest) = tokenizer.decode_rest().map_err(E::msg)? {
+        print!("{rest}");
+    }
     let dt = start_gen.elapsed();
     println!(
         "\n{} tokens generated ({:.2} token/s)\n",
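The commit replaces the per-token `id_to_token` heuristic (which can print garbage when a tokenizer splits a multi-byte UTF-8 character across token ids) with candle's `TokenOutputStream`, which buffers ids and only emits text once it decodes cleanly. Below is a minimal, self-contained sketch of that incremental-decoding idea. The `StreamDecoder` type, its field names, and the toy byte-level tokenizer are illustrative assumptions for this sketch, not candle's actual implementation; the real type lives in the `candle_examples::token_output_stream` module.

use std::io::Write;

// Sketch of incremental token-stream decoding: buffer token ids,
// re-decode cumulatively, and print only the suffix that is now
// "stable", i.e. no longer ends mid-way through a UTF-8 character.
struct StreamDecoder<F: Fn(&[u32]) -> String> {
    decode: F,        // decodes a slice of token ids to text
    tokens: Vec<u32>, // every id seen so far
    printed: usize,   // byte offset into the decoded text already emitted
}

impl<F: Fn(&[u32]) -> String> StreamDecoder<F> {
    fn new(decode: F) -> Self {
        Self { decode, tokens: Vec::new(), printed: 0 }
    }

    /// Push one token id; return any text that is now safe to print.
    fn next_token(&mut self, id: u32) -> Option<String> {
        self.tokens.push(id);
        let text = (self.decode)(&self.tokens);
        // A trailing replacement character means the newest token ended
        // mid-character; hold the output until more ids arrive.
        if text.ends_with('\u{FFFD}') {
            return None;
        }
        // Assumes detokenization is prefix-stable, so `printed` stays on a
        // char boundary; `get` returns None instead of panicking otherwise.
        match text.get(self.printed..) {
            Some(new) if !new.is_empty() => {
                self.printed = text.len();
                Some(new.to_string())
            }
            _ => None,
        }
    }

    /// Flush whatever is still buffered once generation ends.
    fn decode_rest(&mut self) -> Option<String> {
        let text = (self.decode)(&self.tokens);
        let rest = text.get(self.printed..).filter(|s| !s.is_empty()).map(str::to_string);
        self.printed = text.len();
        rest
    }
}

fn main() -> std::io::Result<()> {
    // Toy byte-level "tokenizer": each id is one byte of a UTF-8 string,
    // so the two bytes of 'é' arrive as separate tokens.
    let bytes = "héllo".as_bytes().to_vec();
    let decode = |ids: &[u32]| {
        let bs: Vec<u8> = ids.iter().map(|&i| bytes[i as usize]).collect();
        String::from_utf8_lossy(&bs).into_owned()
    };
    let mut stream = StreamDecoder::new(decode);
    for id in 0..bytes.len() as u32 {
        if let Some(t) = stream.next_token(id) {
            print!("{t}"); // "h", then "é" only after its second byte, ...
            std::io::stdout().flush()?;
        }
    }
    if let Some(rest) = stream.decode_rest() {
        print!("{rest}");
    }
    println!();
    Ok(())
}

This sketch re-decodes the full buffer on every step for clarity, which is quadratic over a long generation; a production version would instead track indices into the token buffer and decode only a small trailing window, which is the kind of bookkeeping the calls in the diff (`next_token`, `decode_rest`) hide behind the `TokenOutputStream` interface.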