summaryrefslogtreecommitdiff
path: root/candle-examples/examples/phi/main.rs
diff options
context:
space:
mode:
authorRadamés Ajna <radamajna@gmail.com>2023-09-26 01:21:22 -0700
committerGitHub <noreply@github.com>2023-09-26 09:21:22 +0100
commit2dd43d6cdd3242bcbe49a0558e56e24549a866d0 (patch)
tree4ae63516c6e537d409bc913dd61e52c252dadc42 /candle-examples/examples/phi/main.rs
parent1fcac4afede215d44c4bf97c8b8c5bad06fcba09 (diff)
downloadcandle-2dd43d6cdd3242bcbe49a0558e56e24549a866d0.tar.gz
candle-2dd43d6cdd3242bcbe49a0558e56e24549a866d0.tar.bz2
candle-2dd43d6cdd3242bcbe49a0558e56e24549a866d0.zip
add eos token to phi example (#965)
* add eos token to phi example * rustfmt + get the token directly. --------- Co-authored-by: laurent <laurent.mazare@gmail.com>
Diffstat (limited to 'candle-examples/examples/phi/main.rs')
-rw-r--r--candle-examples/examples/phi/main.rs7
1 files changed, 7 insertions, 0 deletions
diff --git a/candle-examples/examples/phi/main.rs b/candle-examples/examples/phi/main.rs
index fe365e18..ab37ed5f 100644
--- a/candle-examples/examples/phi/main.rs
+++ b/candle-examples/examples/phi/main.rs
@@ -66,6 +66,10 @@ impl TextGeneration {
.to_vec();
let mut new_tokens = vec![];
+ let eos_token = match self.tokenizer.get_vocab(true).get("<|endoftext|>") {
+ Some(token) => *token,
+ None => anyhow::bail!("cannot find the endoftext token"),
+ };
let start_gen = std::time::Instant::now();
for index in 0..sample_len {
let context_size = if index > 0 { 1 } else { tokens.len() };
@@ -90,6 +94,9 @@ impl TextGeneration {
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
new_tokens.push(next_token);
+ if next_token == eos_token {
+ break;
+ }
let token = self.tokenizer.decode(&[next_token], true).map_err(E::msg)?;
print!("{token}");
std::io::stdout().flush()?;