summaryrefslogtreecommitdiff
path: root/candle-examples/examples/llama
diff options
context:
space:
mode:
authorJuarez Bochi <jbochi@gmail.com>2023-09-12 09:10:16 -0700
committerGitHub <noreply@github.com>2023-09-12 18:10:16 +0200
commit805bf9ffa78119a1a7e047b4ddf6b2ea7df4d94f (patch)
tree0df65e2e6fee356d2345954701ec3d47796ae7ee /candle-examples/examples/llama
parent42da17694a4214a3e39e0d64afc22635ce83f557 (diff)
downloadcandle-805bf9ffa78119a1a7e047b4ddf6b2ea7df4d94f.tar.gz
candle-805bf9ffa78119a1a7e047b4ddf6b2ea7df4d94f.tar.bz2
candle-805bf9ffa78119a1a7e047b4ddf6b2ea7df4d94f.zip
Implement top_p / nucleus sampling (#819)
* Implement top_p / nucleus sampling * Update changelog * rustfmt * Add tests * Fix clippy warning * Fix another clippy error
Diffstat (limited to 'candle-examples/examples/llama')
-rw-r--r--candle-examples/examples/llama/main.rs6
1 files changed, 5 insertions, 1 deletions
diff --git a/candle-examples/examples/llama/main.rs b/candle-examples/examples/llama/main.rs
index db3d216c..b2d7d938 100644
--- a/candle-examples/examples/llama/main.rs
+++ b/candle-examples/examples/llama/main.rs
@@ -42,6 +42,10 @@ struct Args {
#[arg(long)]
temperature: Option<f64>,
+ /// Nucleus sampling probability cutoff.
+ #[arg(long)]
+ top_p: Option<f64>,
+
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
@@ -193,7 +197,7 @@ fn main() -> Result<()> {
println!("starting the inference loop");
print!("{prompt}");
- let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature);
+ let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature, args.top_p);
let start_gen = std::time::Instant::now();
let mut index_pos = 0;
let mut token_generated = 0;