summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMahmoud <botchm@hotmail.com>2023-09-20 00:54:56 -0700
committerGitHub <noreply@github.com>2023-09-20 08:54:56 +0100
commit098dd0d1e9cc2b1ca902e4e0d77a9abe3de72a9c (patch)
tree8ab42a64b89659445dc0df2398448dabb714aa75
parent05626ef492300a4f99c87555304ec863071722d5 (diff)
downloadcandle-098dd0d1e9cc2b1ca902e4e0d77a9abe3de72a9c.tar.gz
candle-098dd0d1e9cc2b1ca902e4e0d77a9abe3de72a9c.tar.bz2
candle-098dd0d1e9cc2b1ca902e4e0d77a9abe3de72a9c.zip
fix: add missing`top_p` in llama_multiprocess (#905)
-rw-r--r--candle-examples/examples/llama_multiprocess/main.rs6
1 files changed, 5 insertions, 1 deletions
diff --git a/candle-examples/examples/llama_multiprocess/main.rs b/candle-examples/examples/llama_multiprocess/main.rs
index 17dc90e2..8a13ce6c 100644
--- a/candle-examples/examples/llama_multiprocess/main.rs
+++ b/candle-examples/examples/llama_multiprocess/main.rs
@@ -89,6 +89,10 @@ struct Args {
#[arg(long)]
temperature: Option<f64>,
+ /// Nucleus sampling probability cutoff.
+ #[arg(long)]
+ top_p: Option<f64>,
+
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
@@ -222,7 +226,7 @@ fn main() -> Result<()> {
.to_vec();
println!("starting the inference loop");
- let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature);
+ let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature, args.top_p);
let mut new_tokens = vec![];
let start_gen = std::time::Instant::now();
let mut index_pos = 0;