author     Juarez Bochi <jbochi@gmail.com>          2023-09-12 09:10:16 -0700
committer  GitHub <noreply@github.com>              2023-09-12 18:10:16 +0200
commit     805bf9ffa78119a1a7e047b4ddf6b2ea7df4d94f (patch)
tree       0df65e2e6fee356d2345954701ec3d47796ae7ee /candle-examples/examples
parent     42da17694a4214a3e39e0d64afc22635ce83f557 (diff)
Implement top_p / nucleus sampling (#819)
* Implement top_p / nucleus sampling
* Update changelog
* rustfmt
* Add tests
* Fix clippy warning
* Fix another clippy error
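
For context on the technique this commit implements: nucleus (top-p) sampling draws the next token from the smallest set of highest-probability tokens whose cumulative probability reaches top_p, renormalizing within that set before sampling. The sketch below is illustrative only and is not candle's actual LogitsProcessor implementation; the sample_top_p helper, its signature, and the caller-supplied uniform draw u are assumptions made for the example.

    /// Minimal nucleus (top-p) sampling sketch over an already-normalized
    /// probability distribution. `u` is a uniform draw in [0, 1) supplied
    /// by the caller so the function stays dependency-free.
    fn sample_top_p(probs: &[f32], top_p: f32, u: f32) -> usize {
        // Order token indices by probability, highest first.
        let mut indices: Vec<usize> = (0..probs.len()).collect();
        indices.sort_by(|&a, &b| probs[b].total_cmp(&probs[a]));

        // Keep the smallest prefix whose cumulative mass reaches top_p.
        let mut cumsum = 0.0f32;
        let mut cutoff = indices.len();
        for (i, &idx) in indices.iter().enumerate() {
            cumsum += probs[idx];
            if cumsum >= top_p {
                cutoff = i + 1;
                break;
            }
        }
        let nucleus = &indices[..cutoff];

        // Renormalize within the nucleus and sample from it.
        let total: f32 = nucleus.iter().map(|&idx| probs[idx]).sum();
        let mut target = u * total;
        for &idx in nucleus {
            target -= probs[idx];
            if target <= 0.0 {
                return idx;
            }
        }
        *nucleus.last().unwrap()
    }

With probs = [0.5, 0.3, 0.1, 0.05, 0.05] and top_p = 0.8, only the first two tokens form the nucleus, so low-probability tail tokens can never be sampled regardless of the random draw.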
Diffstat (limited to 'candle-examples/examples')
-rw-r--r--  candle-examples/examples/bigcode/main.rs    | 16
-rw-r--r--  candle-examples/examples/falcon/main.rs     | 37
-rw-r--r--  candle-examples/examples/llama/main.rs      |  6
-rw-r--r--  candle-examples/examples/llama2-c/main.rs   |  7
-rw-r--r--  candle-examples/examples/quantized/main.rs  |  6
5 files changed, 54 insertions, 18 deletions
diff --git a/candle-examples/examples/bigcode/main.rs b/candle-examples/examples/bigcode/main.rs
index 3540f75d..5f17109e 100644
--- a/candle-examples/examples/bigcode/main.rs
+++ b/candle-examples/examples/bigcode/main.rs
@@ -28,9 +28,10 @@ impl TextGeneration {
         tokenizer: Tokenizer,
         seed: u64,
         temp: Option<f64>,
+        top_p: Option<f64>,
         device: &Device,
     ) -> Self {
-        let logits_processor = LogitsProcessor::new(seed, temp);
+        let logits_processor = LogitsProcessor::new(seed, temp, top_p);
         Self {
             model,
             tokenizer,
@@ -94,6 +95,10 @@ struct Args {
     #[arg(long)]
     temperature: Option<f64>,
 
+    /// Nucleus sampling probability cutoff.
+    #[arg(long)]
+    top_p: Option<f64>,
+
     /// The seed to use when generating random samples.
     #[arg(long, default_value_t = 299792458)]
     seed: u64,
@@ -149,7 +154,14 @@ fn main() -> Result<()> {
     let model = GPTBigCode::load(vb, config)?;
     println!("loaded the model in {:?}", start.elapsed());
 
-    let mut pipeline = TextGeneration::new(model, tokenizer, args.seed, args.temperature, &device);
+    let mut pipeline = TextGeneration::new(
+        model,
+        tokenizer,
+        args.seed,
+        args.temperature,
+        args.top_p,
+        &device,
+    );
     pipeline.run(&args.prompt, args.sample_len)?;
     Ok(())
 }
diff --git a/candle-examples/examples/falcon/main.rs b/candle-examples/examples/falcon/main.rs
index c45fe545..b0973d64 100644
--- a/candle-examples/examples/falcon/main.rs
+++ b/candle-examples/examples/falcon/main.rs
@@ -25,17 +25,25 @@ struct TextGeneration {
     repeat_last_n: usize,
 }
 
+struct GenerationOptions {
+    temp: Option<f64>,
+    top_p: Option<f64>,
+    repeat_penalty: f32,
+    repeat_last_n: usize,
+}
+
 impl TextGeneration {
     fn new(
         model: Falcon,
         tokenizer: Tokenizer,
+        generation_options: GenerationOptions,
         seed: u64,
-        temp: Option<f64>,
         device: &Device,
-        repeat_penalty: f32,
-        repeat_last_n: usize,
     ) -> Self {
-        let logits_processor = LogitsProcessor::new(seed, temp);
+        let logits_processor =
+            LogitsProcessor::new(seed, generation_options.temp, generation_options.top_p);
+        let repeat_penalty = generation_options.repeat_penalty;
+        let repeat_last_n = generation_options.repeat_last_n;
         Self {
             model,
             tokenizer,
@@ -118,6 +126,10 @@ struct Args {
     #[arg(long)]
     temperature: Option<f64>,
 
+    /// Nucleus sampling probability cutoff.
+    #[arg(long)]
+    top_p: Option<f64>,
+
     /// The seed to use when generating random samples.
     #[arg(long, default_value_t = 299792458)]
     seed: u64,
@@ -185,15 +197,14 @@ fn main() -> Result<()> {
     let model = Falcon::load(vb, config)?;
     println!("loaded the model in {:?}", start.elapsed());
 
-    let mut pipeline = TextGeneration::new(
-        model,
-        tokenizer,
-        args.seed,
-        args.temperature,
-        &device,
-        args.repeat_penalty,
-        args.repeat_last_n,
-    );
+    let generation_options = GenerationOptions {
+        temp: args.temperature,
+        top_p: args.top_p,
+        repeat_penalty: args.repeat_penalty,
+        repeat_last_n: args.repeat_last_n,
+    };
+    let mut pipeline =
+        TextGeneration::new(model, tokenizer, generation_options, args.seed, &device);
     pipeline.run(&args.prompt, args.sample_len)?;
     Ok(())
 }
diff --git a/candle-examples/examples/llama/main.rs b/candle-examples/examples/llama/main.rs
index db3d216c..b2d7d938 100644
--- a/candle-examples/examples/llama/main.rs
+++ b/candle-examples/examples/llama/main.rs
@@ -42,6 +42,10 @@ struct Args {
     #[arg(long)]
     temperature: Option<f64>,
 
+    /// Nucleus sampling probability cutoff.
+    #[arg(long)]
+    top_p: Option<f64>,
+
     /// The seed to use when generating random samples.
     #[arg(long, default_value_t = 299792458)]
     seed: u64,
@@ -193,7 +197,7 @@ fn main() -> Result<()> {
     println!("starting the inference loop");
     print!("{prompt}");
 
-    let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature);
+    let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature, args.top_p);
     let start_gen = std::time::Instant::now();
     let mut index_pos = 0;
     let mut token_generated = 0;
diff --git a/candle-examples/examples/llama2-c/main.rs b/candle-examples/examples/llama2-c/main.rs
index e0ade322..e752a494 100644
--- a/candle-examples/examples/llama2-c/main.rs
+++ b/candle-examples/examples/llama2-c/main.rs
@@ -27,6 +27,10 @@ struct InferenceCmd {
     #[arg(long)]
     temperature: Option<f64>,
 
+    /// Nucleus sampling probability cutoff.
+    #[arg(long)]
+    top_p: Option<f64>,
+
     #[arg(long, default_value = "")]
     prompt: String,
 
@@ -133,6 +137,7 @@ fn main() -> anyhow::Result<()> {
         None => {
             let cmd = InferenceCmd {
                 temperature: None,
+                top_p: None,
                 prompt: "".to_string(),
                 config: None,
                 model_id: "karpathy/tinyllamas".to_string(),
@@ -256,7 +261,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
     let model = Llama::load(vb, &cache, config)?;
 
     println!("starting the inference loop");
-    let mut logits_processor = LogitsProcessor::new(299792458, args.temperature);
+    let mut logits_processor = LogitsProcessor::new(299792458, args.temperature, args.top_p);
     let mut index_pos = 0;
 
     print!("{}", args.prompt);
diff --git a/candle-examples/examples/quantized/main.rs b/candle-examples/examples/quantized/main.rs
index c8179d33..a80ad420 100644
--- a/candle-examples/examples/quantized/main.rs
+++ b/candle-examples/examples/quantized/main.rs
@@ -71,6 +71,10 @@ struct Args {
     #[arg(long, default_value_t = 0.8)]
     temperature: f64,
 
+    /// Nucleus sampling probability cutoff.
+    #[arg(long)]
+    top_p: Option<f64>,
+
     /// The seed to use when generating random samples.
     #[arg(long, default_value_t = 299792458)]
     seed: u64,
@@ -310,7 +314,7 @@ fn main() -> anyhow::Result<()> {
         prompt_tokens
     };
     let mut all_tokens = vec![];
-    let mut logits_processor = LogitsProcessor::new(args.seed, temperature);
+    let mut logits_processor = LogitsProcessor::new(args.seed, temperature, args.top_p);
 
     let start_prompt_processing = std::time::Instant::now();
     let mut next_token = {
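
Taken together, every call site moves from LogitsProcessor::new(seed, temp) to LogitsProcessor::new(seed, temp, top_p), with top_p = None preserving the old behavior; the falcon example additionally groups its sampling knobs into a GenerationOptions struct to keep the constructor's argument list short. The sketch below shows how a caller might drive the updated constructor. The constructor arguments match the diffs above, but the import paths, the sample() call, and the toy logits tensor are assumptions made for illustration, not code from this commit.

    use candle_core::{Device, Result, Tensor};
    use candle_transformers::generation::LogitsProcessor;

    fn sample_once() -> Result<u32> {
        // `None` for top_p keeps the pre-#819 sampling behavior, so existing
        // callers only have to thread one extra argument through.
        let temperature = Some(0.8);
        let top_p = Some(0.9);
        let mut logits_processor = LogitsProcessor::new(299792458, temperature, top_p);

        // Toy logits for a four-token vocabulary: one f32 value per entry.
        let logits = Tensor::new(&[1.0f32, 2.0, 3.0, 4.0], &Device::Cpu)?;
        logits_processor.sample(&logits)
    }

On the command line, each example then accepts a nucleus cutoff alongside its existing --temperature flag (clap derives the exact flag spelling from the top_p field).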