summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-08-15 12:18:20 +0100
committerGitHub <noreply@github.com>2023-08-15 12:18:20 +0100
commit5b1690fffa319f87c60ebe3b65e61dc7fa79fe2c (patch)
tree9b189e5756ade9958227eedc6f1ca9579008fcd5
parent3cc87058b7dcd73a723cd6375861743f535c0e21 (diff)
downloadcandle-5b1690fffa319f87c60ebe3b65e61dc7fa79fe2c.tar.gz
candle-5b1690fffa319f87c60ebe3b65e61dc7fa79fe2c.tar.bz2
candle-5b1690fffa319f87c60ebe3b65e61dc7fa79fe2c.zip
Tweak the llama example. (#450)
-rw-r--r--candle-examples/examples/llama/main.rs77
1 files changed, 14 insertions, 63 deletions
diff --git a/candle-examples/examples/llama/main.rs b/candle-examples/examples/llama/main.rs
index 9cb9d91d..98ff9cca 100644
--- a/candle-examples/examples/llama/main.rs
+++ b/candle-examples/examples/llama/main.rs
@@ -19,62 +19,14 @@ use candle::{DType, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::api::sync::Api;
+use std::io::Write;
mod model;
use model::{Config, Llama};
+const EOS_TOKEN: &str = "</s>";
const MAX_SEQ_LEN: usize = 4096;
-const DEFAULT_PROMPT: &str = r"
-EDWARD:
-I wonder how our princely father 'scaped,
-Or whether he be 'scaped away or no
-From Clifford's and Northumberland's pursuit:
-Had he been ta'en, we should have heard the news;
-Had he been slain, we should have heard the news;
-Or had he 'scaped, methinks we should have heard
-The happy tidings of his good escape.
-How fares my brother? why is he so sad?
-
-RICHARD:
-I cannot joy, until I be resolved
-Where our right valiant father is become.
-I saw him in the battle range about;
-And watch'd him how he singled Clifford forth.
-Methought he bore him in the thickest troop
-As doth a lion in a herd of neat;
-Or as a bear, encompass'd round with dogs,
-Who having pinch'd a few and made them cry,
-The rest stand all aloof, and bark at him.
-So fared our father with his enemies;
-So fled his enemies my warlike father:
-Methinks, 'tis prize enough to be his son.
-See how the morning opes her golden gates,
-And takes her farewell of the glorious sun!
-How well resembles it the prime of youth,
-Trimm'd like a younker prancing to his love!
-
-EDWARD:
-Dazzle mine eyes, or do I see three suns?
-
-RICHARD:
-Three glorious suns, each one a perfect sun;
-Not separated with the racking clouds,
-But sever'd in a pale clear-shining sky.
-See, see! they join, embrace, and seem to kiss,
-As if they vow'd some league inviolable:
-Now are they but one lamp, one light, one sun.
-In this the heaven figures some event.
-
-EDWARD:
-'Tis wondrous strange, the like yet never heard of.
-I think it cites us, brother, to the field,
-That we, the sons of brave Plantagenet,
-Each one already blazing by our meeds,
-Should notwithstanding join our lights together
-And over-shine the earth as this the world.
-Whate'er it bodes, henceforward will I bear
-Upon my target three fair-shining suns.
-";
+const DEFAULT_PROMPT: &str = "My favorite theorem is ";
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
@@ -207,6 +159,7 @@ fn main() -> Result<()> {
}
};
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
+ let eos_token_id = tokenizer.token_to_id(EOS_TOKEN);
let prompt = args.prompt.as_ref().map_or(DEFAULT_PROMPT, |p| p.as_str());
let mut tokens = tokenizer
.encode(prompt, true)
@@ -215,8 +168,8 @@ fn main() -> Result<()> {
.to_vec();
println!("starting the inference loop");
+ print!("{prompt}");
let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature);
- let mut new_tokens = vec![];
let start_gen = std::time::Instant::now();
let mut index_pos = 0;
let mut token_generated = 0;
@@ -235,19 +188,17 @@ fn main() -> Result<()> {
let next_token = logits_processor.sample(&logits)?;
token_generated += 1;
tokens.push(next_token);
- new_tokens.push(next_token);
- let tk = tokenizer.decode(&[next_token], true).map_err(E::msg)?;
- if [",", ".", ":", "?", "'", "\""].contains(&tk.as_str())
- || index == args.sample_len - 1
- || next_token == 2
- {
- //2 for end token
- print!("{} ", tokenizer.decode(&new_tokens, true).map_err(E::msg)?);
- new_tokens.clear();
+ // Extracting the last token as a string is complicated, here we just apply some simple
+ // heuristics as it seems to work well enough for this example. See the following for more
+ // details:
+ // https://github.com/huggingface/tokenizers/issues/1141#issuecomment-1562644141
+ if let Some(text) = tokenizer.id_to_token(next_token) {
+ let text = text.replace('▁', " ").replace("<0x0A>", "\n");
+ print!("{text}");
+ std::io::stdout().flush()?;
}
-
- if next_token == 2 {
+ if Some(next_token) == eos_token_id {
break;
}
}