summaryrefslogtreecommitdiff
path: root/candle-examples/examples/llama_multiprocess/main.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-examples/examples/llama_multiprocess/main.rs')
-rw-r--r--candle-examples/examples/llama_multiprocess/main.rs28
1 files changed, 16 insertions, 12 deletions
diff --git a/candle-examples/examples/llama_multiprocess/main.rs b/candle-examples/examples/llama_multiprocess/main.rs
index 22c121dd..f9e87432 100644
--- a/candle-examples/examples/llama_multiprocess/main.rs
+++ b/candle-examples/examples/llama_multiprocess/main.rs
@@ -247,20 +247,24 @@ fn main() -> Result<()> {
let next_token = logits_processor.sample(&logits)?;
tokens.push(next_token);
new_tokens.push(next_token);
- println!("> {:?}", start_gen.elapsed());
+ if rank == 0 {
+ println!("> {:?}", start_gen.elapsed());
+ println!(
+ "{} token: {} '{}'",
+ index + 1,
+ next_token,
+ tokenizer.decode(vec![next_token], true).map_err(E::msg)?
+ );
+ }
+ }
+ let dt = start_gen.elapsed();
+ if rank == 0 {
println!(
- "{} token: {} '{}'",
- index + 1,
- next_token,
- tokenizer.decode(vec![next_token], true).map_err(E::msg)?
+ "{} tokens generated ({} token/s)\n----\n{}\n----",
+ args.sample_len,
+ args.sample_len as f64 / dt.as_secs_f64(),
+ tokenizer.decode(new_tokens, true).map_err(E::msg)?
);
}
- let dt = start_gen.elapsed();
- println!(
- "{} tokens generated ({} token/s)\n----\n{}\n----",
- args.sample_len,
- args.sample_len as f64 / dt.as_secs_f64(),
- tokenizer.decode(new_tokens, true).map_err(E::msg)?
- );
Ok(())
}