diff options
author | Radamés Ajna <radamajna@gmail.com> | 2023-09-26 01:21:22 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-09-26 09:21:22 +0100 |
commit | 2dd43d6cdd3242bcbe49a0558e56e24549a866d0 (patch) | |
tree | 4ae63516c6e537d409bc913dd61e52c252dadc42 /candle-examples/examples/phi/main.rs | |
parent | 1fcac4afede215d44c4bf97c8b8c5bad06fcba09 (diff) | |
download | candle-2dd43d6cdd3242bcbe49a0558e56e24549a866d0.tar.gz candle-2dd43d6cdd3242bcbe49a0558e56e24549a866d0.tar.bz2 candle-2dd43d6cdd3242bcbe49a0558e56e24549a866d0.zip |
add eos token to phi example (#965)
* add eos token to phi example
* rustfmt + get the token directly.
---------
Co-authored-by: laurent <laurent.mazare@gmail.com>
Diffstat (limited to 'candle-examples/examples/phi/main.rs')
-rw-r--r-- | candle-examples/examples/phi/main.rs | 7 |
1 files changed, 7 insertions, 0 deletions
```diff
diff --git a/candle-examples/examples/phi/main.rs b/candle-examples/examples/phi/main.rs
index fe365e18..ab37ed5f 100644
--- a/candle-examples/examples/phi/main.rs
+++ b/candle-examples/examples/phi/main.rs
@@ -66,6 +66,10 @@ impl TextGeneration {
             .to_vec();
 
         let mut new_tokens = vec![];
+        let eos_token = match self.tokenizer.get_vocab(true).get("<|endoftext|>") {
+            Some(token) => *token,
+            None => anyhow::bail!("cannot find the endoftext token"),
+        };
         let start_gen = std::time::Instant::now();
         for index in 0..sample_len {
             let context_size = if index > 0 { 1 } else { tokens.len() };
@@ -90,6 +94,9 @@ impl TextGeneration {
             let next_token = self.logits_processor.sample(&logits)?;
             tokens.push(next_token);
             new_tokens.push(next_token);
+            if next_token == eos_token {
+                break;
+            }
             let token = self.tokenizer.decode(&[next_token], true).map_err(E::msg)?;
             print!("{token}");
             std::io::stdout().flush()?;
```