| author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-02-15 16:47:33 +0100 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-02-15 16:47:33 +0100 |
| commit | 7c7400fb6320737cf345004df12e0733ac99704c (patch) | |
| tree | 6654ee8825f36274a9fa674688180cb1ef62a081 /candle-examples | |
| parent | 058a910d0e1e9bb511209b1f756a6cd07a347889 (diff) | |
Use the tokenizer-output-stream in the llama example. (#1715)
* Use the tokenizer-output-stream in the llama example.
* Also use tokenizer-output-stream for llama2-c.
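For context: both examples previously turned each sampled id into a string with `tokenizer.id_to_token` plus ad-hoc replacements of `▁` and `<0x0A>`, while `TokenOutputStream` re-decodes the accumulated ids and emits only the newly produced suffix, which handles byte-fallback tokens correctly. Below is a minimal sketch of that idea, not candle's actual implementation: the `StreamDecoder` name, the U+FFFD hold-back heuristic, and the use of the `tokenizers` and `anyhow` crates are assumptions of the sketch.

```rust
use tokenizers::Tokenizer;

/// Illustrative stand-in for candle's TokenOutputStream: keep every
/// generated id, re-decode the running sequence each step, and emit only
/// the suffix that has not been printed yet.
struct StreamDecoder {
    tokenizer: Tokenizer,
    tokens: Vec<u32>,
    printed: usize, // bytes of decoded text already emitted
}

impl StreamDecoder {
    fn new(tokenizer: Tokenizer) -> Self {
        Self { tokenizer, tokens: Vec::new(), printed: 0 }
    }

    /// Push one generated token id; return any newly decoded text.
    fn next_token(&mut self, token: u32) -> anyhow::Result<Option<String>> {
        self.tokens.push(token);
        let text = self
            .tokenizer
            .decode(&self.tokens, true) // true = skip special tokens
            .map_err(anyhow::Error::msg)?;
        // Hold back while the text ends in U+FFFD: byte-fallback tokens
        // such as <0x0A> decode to the replacement character until all of
        // their bytes have arrived. (A heuristic chosen for this sketch;
        // the real implementation uses its own check.)
        if text.len() > self.printed && !text.ends_with('\u{FFFD}') {
            // Assumes decoding is append-only, so `printed` remains a
            // valid char boundary into the freshly decoded text.
            let new = text[self.printed..].to_string();
            self.printed = text.len();
            Ok(Some(new))
        } else {
            Ok(None)
        }
    }

    /// Flush whatever next_token held back, mirroring the decode_rest
    /// calls added in the diff below.
    fn decode_rest(&mut self) -> anyhow::Result<Option<String>> {
        let text = self
            .tokenizer
            .decode(&self.tokens, true)
            .map_err(anyhow::Error::msg)?;
        if text.len() > self.printed {
            let rest = text[self.printed..].to_string();
            self.printed = text.len();
            Ok(Some(rest))
        } else {
            Ok(None)
        }
    }
}
```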
Diffstat (limited to 'candle-examples')
| -rw-r--r-- | candle-examples/examples/llama/main.rs | 20 |
| -rw-r--r-- | candle-examples/examples/llama2-c/main.rs | 13 |
| -rw-r--r-- | candle-examples/examples/mistral/main.rs | 2 |
| -rw-r--r-- | candle-examples/examples/mixtral/main.rs | 2 |
4 files changed, 17 insertions, 20 deletions
diff --git a/candle-examples/examples/llama/main.rs b/candle-examples/examples/llama/main.rs
index 251c184b..e95321c7 100644
--- a/candle-examples/examples/llama/main.rs
+++ b/candle-examples/examples/llama/main.rs
@@ -57,7 +57,7 @@ struct Args {
     seed: u64,
 
     /// The length of the sample to generate (in tokens).
-    #[arg(long, default_value_t = 100)]
+    #[arg(long, default_value_t = 10000)]
     sample_len: usize,
 
     /// Disable the key-value cache.
@@ -143,7 +143,6 @@ fn main() -> Result<()> {
         }
         Which::TinyLlama1_1BChat => vec![api.get("model.safetensors")?],
     };
-    println!("building the model");
     let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;
 
     let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
@@ -157,6 +156,7 @@ fn main() -> Result<()> {
         .map_err(E::msg)?
         .get_ids()
         .to_vec();
+    let mut tokenizer = candle_examples::token_output_stream::TokenOutputStream::new(tokenizer);
 
     println!("starting the inference loop");
     print!("{prompt}");
@@ -190,18 +190,16 @@ fn main() -> Result<()> {
         token_generated += 1;
         tokens.push(next_token);
 
-        // Extracting the last token as a string is complicated, here we just apply some simple
-        // heuristics as it seems to work well enough for this example. See the following for more
-        // details:
-        // https://github.com/huggingface/tokenizers/issues/1141#issuecomment-1562644141
-        if let Some(text) = tokenizer.id_to_token(next_token) {
-            let text = text.replace('▁', " ").replace("<0x0A>", "\n");
-            print!("{text}");
-            std::io::stdout().flush()?;
-        }
         if Some(next_token) == eos_token_id {
             break;
         }
+        if let Some(t) = tokenizer.next_token(next_token)? {
+            print!("{t}");
+            std::io::stdout().flush()?;
+        }
+    }
+    if let Some(rest) = tokenizer.decode_rest().map_err(E::msg)? {
+        print!("{rest}");
     }
     let dt = start_gen.elapsed();
     println!(
diff --git a/candle-examples/examples/llama2-c/main.rs b/candle-examples/examples/llama2-c/main.rs
index 9d42dcc8..27ebc80f 100644
--- a/candle-examples/examples/llama2-c/main.rs
+++ b/candle-examples/examples/llama2-c/main.rs
@@ -328,6 +328,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
         .map_err(E::msg)?
         .get_ids()
         .to_vec();
+    let mut tokenizer = candle_examples::token_output_stream::TokenOutputStream::new(tokenizer);
 
     let start_gen = std::time::Instant::now();
     for index in 0.. {
@@ -353,16 +354,14 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
         let next_token = logits_processor.sample(&logits)?;
         tokens.push(next_token);
 
-        // Extracting the last token as a string is complicated, here we just apply some simple
-        // heuristics as it seems to work well enough for this example. See the following for more
-        // details:
-        // https://github.com/huggingface/tokenizers/issues/1141#issuecomment-1562644141
-        if let Some(text) = tokenizer.id_to_token(next_token) {
-            let text = text.replace('▁', " ").replace("<0x0A>", "\n");
-            print!("{text}");
+        if let Some(t) = tokenizer.next_token(next_token)? {
+            print!("{t}");
             std::io::stdout().flush()?;
         }
     }
+    if let Some(rest) = tokenizer.decode_rest().map_err(E::msg)? {
+        print!("{rest}");
+    }
     let dt = start_gen.elapsed();
     println!(
         "\n{} tokens generated ({:.2} token/s)\n",
diff --git a/candle-examples/examples/mistral/main.rs b/candle-examples/examples/mistral/main.rs
index bad86098..1cf4107c 100644
--- a/candle-examples/examples/mistral/main.rs
+++ b/candle-examples/examples/mistral/main.rs
@@ -152,7 +152,7 @@ struct Args {
     seed: u64,
 
     /// The length of the sample to generate (in tokens).
-    #[arg(long, short = 'n', default_value_t = 100)]
+    #[arg(long, short = 'n', default_value_t = 10000)]
     sample_len: usize,
 
     #[arg(long)]
diff --git a/candle-examples/examples/mixtral/main.rs b/candle-examples/examples/mixtral/main.rs
index 1b1a4b36..fe47e537 100644
--- a/candle-examples/examples/mixtral/main.rs
+++ b/candle-examples/examples/mixtral/main.rs
@@ -143,7 +143,7 @@ struct Args {
     seed: u64,
 
     /// The length of the sample to generate (in tokens).
-    #[arg(long, short = 'n', default_value_t = 100)]
+    #[arg(long, short = 'n', default_value_t = 10000)]
     sample_len: usize,
 
     #[arg(long, default_value = "mistralai/Mixtral-8x7B-v0.1")]