| author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-08-15 12:18:20 +0100 |
| --- | --- | --- |
| committer | GitHub <noreply@github.com> | 2023-08-15 12:18:20 +0100 |
| commit | 5b1690fffa319f87c60ebe3b65e61dc7fa79fe2c (patch) | |
| tree | 9b189e5756ade9958227eedc6f1ca9579008fcd5 | |
| parent | 3cc87058b7dcd73a723cd6375861743f535c0e21 (diff) | |
Tweak the llama example. (#450)
-rw-r--r-- | candle-examples/examples/llama/main.rs | 77 |
1 file changed, 14 insertions(+), 63 deletions(-)
diff --git a/candle-examples/examples/llama/main.rs b/candle-examples/examples/llama/main.rs
index 9cb9d91d..98ff9cca 100644
--- a/candle-examples/examples/llama/main.rs
+++ b/candle-examples/examples/llama/main.rs
@@ -19,62 +19,14 @@ use candle::{DType, Tensor};
 use candle_nn::VarBuilder;
 use candle_transformers::generation::LogitsProcessor;
 use hf_hub::api::sync::Api;
+use std::io::Write;
 
 mod model;
 use model::{Config, Llama};
 
+const EOS_TOKEN: &str = "</s>";
 const MAX_SEQ_LEN: usize = 4096;
-const DEFAULT_PROMPT: &str = r"
-EDWARD:
-I wonder how our princely father 'scaped,
-Or whether he be 'scaped away or no
-From Clifford's and Northumberland's pursuit:
-Had he been ta'en, we should have heard the news;
-Had he been slain, we should have heard the news;
-Or had he 'scaped, methinks we should have heard
-The happy tidings of his good escape.
-How fares my brother? why is he so sad?
-
-RICHARD:
-I cannot joy, until I be resolved
-Where our right valiant father is become.
-I saw him in the battle range about;
-And watch'd him how he singled Clifford forth.
-Methought he bore him in the thickest troop
-As doth a lion in a herd of neat;
-Or as a bear, encompass'd round with dogs,
-Who having pinch'd a few and made them cry,
-The rest stand all aloof, and bark at him.
-So fared our father with his enemies;
-So fled his enemies my warlike father:
-Methinks, 'tis prize enough to be his son.
-See how the morning opes her golden gates,
-And takes her farewell of the glorious sun!
-How well resembles it the prime of youth,
-Trimm'd like a younker prancing to his love!
-
-EDWARD:
-Dazzle mine eyes, or do I see three suns?
-
-RICHARD:
-Three glorious suns, each one a perfect sun;
-Not separated with the racking clouds,
-But sever'd in a pale clear-shining sky.
-See, see! they join, embrace, and seem to kiss,
-As if they vow'd some league inviolable:
-Now are they but one lamp, one light, one sun.
-In this the heaven figures some event.
-
-EDWARD:
-'Tis wondrous strange, the like yet never heard of.
-I think it cites us, brother, to the field,
-That we, the sons of brave Plantagenet,
-Each one already blazing by our meeds,
-Should notwithstanding join our lights together
-And over-shine the earth as this the world.
-Whate'er it bodes, henceforward will I bear
-Upon my target three fair-shining suns.
-";
+const DEFAULT_PROMPT: &str = "My favorite theorem is ";
 
 #[derive(Parser, Debug)]
 #[command(author, version, about, long_about = None)]
@@ -207,6 +159,7 @@ fn main() -> Result<()> {
         }
     };
     let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
+    let eos_token_id = tokenizer.token_to_id(EOS_TOKEN);
     let prompt = args.prompt.as_ref().map_or(DEFAULT_PROMPT, |p| p.as_str());
     let mut tokens = tokenizer
         .encode(prompt, true)
@@ -215,8 +168,8 @@ fn main() -> Result<()> {
         .to_vec();
 
     println!("starting the inference loop");
+    print!("{prompt}");
     let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature);
-    let mut new_tokens = vec![];
     let start_gen = std::time::Instant::now();
     let mut index_pos = 0;
     let mut token_generated = 0;
@@ -235,19 +188,17 @@ fn main() -> Result<()> {
         let next_token = logits_processor.sample(&logits)?;
         token_generated += 1;
         tokens.push(next_token);
-        new_tokens.push(next_token);
-        let tk = tokenizer.decode(&[next_token], true).map_err(E::msg)?;
-        if [",", ".", ":", "?", "'", "\""].contains(&tk.as_str())
-            || index == args.sample_len - 1
-            || next_token == 2
-        {
-            //2 for end token
-            print!("{} ", tokenizer.decode(&new_tokens, true).map_err(E::msg)?);
-            new_tokens.clear();
+        // Extracting the last token as a string is complicated, here we just apply some simple
+        // heuristics as it seems to work well enough for this example. See the following for more
+        // details:
+        // https://github.com/huggingface/tokenizers/issues/1141#issuecomment-1562644141
+        if let Some(text) = tokenizer.id_to_token(next_token) {
+            let text = text.replace('▁', " ").replace("<0x0A>", "\n");
+            print!("{text}");
+            std::io::stdout().flush()?;
        }
-
-        if next_token == 2 {
+        if Some(next_token) == eos_token_id {
             break;
         }
     }