diff options
author | Juarez Bochi <jbochi@gmail.com> | 2023-09-12 09:10:16 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-09-12 18:10:16 +0200 |
commit | 805bf9ffa78119a1a7e047b4ddf6b2ea7df4d94f (patch) | |
tree | 0df65e2e6fee356d2345954701ec3d47796ae7ee /candle-examples/examples/llama | |
parent | 42da17694a4214a3e39e0d64afc22635ce83f557 (diff) | |
download | candle-805bf9ffa78119a1a7e047b4ddf6b2ea7df4d94f.tar.gz candle-805bf9ffa78119a1a7e047b4ddf6b2ea7df4d94f.tar.bz2 candle-805bf9ffa78119a1a7e047b4ddf6b2ea7df4d94f.zip |
Implement top_p / nucleus sampling (#819)
* Implement top_p / nucleus sampling
* Update changelog
* rustfmt
* Add tests
* Fix clippy warning
* Fix another clippy error
Diffstat (limited to 'candle-examples/examples/llama')
-rw-r--r-- | candle-examples/examples/llama/main.rs | 6 |
1 files changed, 5 insertions, 1 deletions
diff --git a/candle-examples/examples/llama/main.rs b/candle-examples/examples/llama/main.rs index db3d216c..b2d7d938 100644 --- a/candle-examples/examples/llama/main.rs +++ b/candle-examples/examples/llama/main.rs @@ -42,6 +42,10 @@ struct Args { #[arg(long)] temperature: Option<f64>, + /// Nucleus sampling probability cutoff. + #[arg(long)] + top_p: Option<f64>, + /// The seed to use when generating random samples. #[arg(long, default_value_t = 299792458)] seed: u64, @@ -193,7 +197,7 @@ fn main() -> Result<()> { println!("starting the inference loop"); print!("{prompt}"); - let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature); + let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature, args.top_p); let start_gen = std::time::Instant::now(); let mut index_pos = 0; let mut token_generated = 0; |