summaryrefslogtreecommitdiff
path: root/candle-examples
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2024-02-15 16:47:33 +0100
committerGitHub <noreply@github.com>2024-02-15 16:47:33 +0100
commit7c7400fb6320737cf345004df12e0733ac99704c (patch)
tree6654ee8825f36274a9fa674688180cb1ef62a081 /candle-examples
parent058a910d0e1e9bb511209b1f756a6cd07a347889 (diff)
downloadcandle-7c7400fb6320737cf345004df12e0733ac99704c.tar.gz
candle-7c7400fb6320737cf345004df12e0733ac99704c.tar.bz2
candle-7c7400fb6320737cf345004df12e0733ac99704c.zip
Use the tokenizer-output-stream in the llama example. (#1715)
* Use the tokenizer-output-stream in the llama example. * Also use tokenizer-output-stream for llama2-c.
Diffstat (limited to 'candle-examples')
-rw-r--r--candle-examples/examples/llama/main.rs20
-rw-r--r--candle-examples/examples/llama2-c/main.rs13
-rw-r--r--candle-examples/examples/mistral/main.rs2
-rw-r--r--candle-examples/examples/mixtral/main.rs2
4 files changed, 17 insertions, 20 deletions
diff --git a/candle-examples/examples/llama/main.rs b/candle-examples/examples/llama/main.rs
index 251c184b..e95321c7 100644
--- a/candle-examples/examples/llama/main.rs
+++ b/candle-examples/examples/llama/main.rs
@@ -57,7 +57,7 @@ struct Args {
seed: u64,
/// The length of the sample to generate (in tokens).
- #[arg(long, default_value_t = 100)]
+ #[arg(long, default_value_t = 10000)]
sample_len: usize,
/// Disable the key-value cache.
@@ -143,7 +143,6 @@ fn main() -> Result<()> {
}
Which::TinyLlama1_1BChat => vec![api.get("model.safetensors")?],
};
- println!("building the model");
let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
@@ -157,6 +156,7 @@ fn main() -> Result<()> {
.map_err(E::msg)?
.get_ids()
.to_vec();
+ let mut tokenizer = candle_examples::token_output_stream::TokenOutputStream::new(tokenizer);
println!("starting the inference loop");
print!("{prompt}");
@@ -190,18 +190,16 @@ fn main() -> Result<()> {
token_generated += 1;
tokens.push(next_token);
- // Extracting the last token as a string is complicated, here we just apply some simple
- // heuristics as it seems to work well enough for this example. See the following for more
- // details:
- // https://github.com/huggingface/tokenizers/issues/1141#issuecomment-1562644141
- if let Some(text) = tokenizer.id_to_token(next_token) {
- let text = text.replace('▁', " ").replace("<0x0A>", "\n");
- print!("{text}");
- std::io::stdout().flush()?;
- }
if Some(next_token) == eos_token_id {
break;
}
+ if let Some(t) = tokenizer.next_token(next_token)? {
+ print!("{t}");
+ std::io::stdout().flush()?;
+ }
+ }
+ if let Some(rest) = tokenizer.decode_rest().map_err(E::msg)? {
+ print!("{rest}");
}
let dt = start_gen.elapsed();
println!(
diff --git a/candle-examples/examples/llama2-c/main.rs b/candle-examples/examples/llama2-c/main.rs
index 9d42dcc8..27ebc80f 100644
--- a/candle-examples/examples/llama2-c/main.rs
+++ b/candle-examples/examples/llama2-c/main.rs
@@ -328,6 +328,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
.map_err(E::msg)?
.get_ids()
.to_vec();
+ let mut tokenizer = candle_examples::token_output_stream::TokenOutputStream::new(tokenizer);
let start_gen = std::time::Instant::now();
for index in 0.. {
@@ -353,16 +354,14 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
let next_token = logits_processor.sample(&logits)?;
tokens.push(next_token);
- // Extracting the last token as a string is complicated, here we just apply some simple
- // heuristics as it seems to work well enough for this example. See the following for more
- // details:
- // https://github.com/huggingface/tokenizers/issues/1141#issuecomment-1562644141
- if let Some(text) = tokenizer.id_to_token(next_token) {
- let text = text.replace('▁', " ").replace("<0x0A>", "\n");
- print!("{text}");
+ if let Some(t) = tokenizer.next_token(next_token)? {
+ print!("{t}");
std::io::stdout().flush()?;
}
}
+ if let Some(rest) = tokenizer.decode_rest().map_err(E::msg)? {
+ print!("{rest}");
+ }
let dt = start_gen.elapsed();
println!(
"\n{} tokens generated ({:.2} token/s)\n",
diff --git a/candle-examples/examples/mistral/main.rs b/candle-examples/examples/mistral/main.rs
index bad86098..1cf4107c 100644
--- a/candle-examples/examples/mistral/main.rs
+++ b/candle-examples/examples/mistral/main.rs
@@ -152,7 +152,7 @@ struct Args {
seed: u64,
/// The length of the sample to generate (in tokens).
- #[arg(long, short = 'n', default_value_t = 100)]
+ #[arg(long, short = 'n', default_value_t = 10000)]
sample_len: usize,
#[arg(long)]
diff --git a/candle-examples/examples/mixtral/main.rs b/candle-examples/examples/mixtral/main.rs
index 1b1a4b36..fe47e537 100644
--- a/candle-examples/examples/mixtral/main.rs
+++ b/candle-examples/examples/mixtral/main.rs
@@ -143,7 +143,7 @@ struct Args {
seed: u64,
/// The length of the sample to generate (in tokens).
- #[arg(long, short = 'n', default_value_t = 100)]
+ #[arg(long, short = 'n', default_value_t = 10000)]
sample_len: usize,
#[arg(long, default_value = "mistralai/Mixtral-8x7B-v0.1")]