author    Juarez Bochi <jbochi@gmail.com>  2023-09-12 09:10:16 -0700
committer GitHub <noreply@github.com>  2023-09-12 18:10:16 +0200
commit    805bf9ffa78119a1a7e047b4ddf6b2ea7df4d94f (patch)
tree      0df65e2e6fee356d2345954701ec3d47796ae7ee /candle-examples/examples
parent    42da17694a4214a3e39e0d64afc22635ce83f557 (diff)
Implement top_p / nucleus sampling (#819)
* Implement top_p / nucleus sampling
* Update changelog
* rustfmt
* Add tests
* Fix clippy warning
* Fix another clippy error
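For context: nucleus (top-p) sampling draws only from the smallest set of highest-probability tokens whose cumulative probability reaches top_p, renormalizing within that set; everything outside the nucleus is dropped. A minimal sketch of the technique in Rust (illustrative only; sample_top_p is a hypothetical helper, and the real LogitsProcessor implementation lives outside this diff):

use rand::Rng;

/// Sample a token index from `probs` using top-p (nucleus) sampling.
/// `probs` is assumed to sum to 1; `top_p` is in (0, 1].
fn sample_top_p(probs: &[f32], top_p: f32, rng: &mut impl Rng) -> usize {
    // Order token indices by probability, highest first.
    let mut indices: Vec<usize> = (0..probs.len()).collect();
    indices.sort_by(|&a, &b| probs[b].total_cmp(&probs[a]));

    // Keep the smallest prefix whose cumulative mass reaches top_p.
    let mut cumsum = 0.0f32;
    let mut cutoff = indices.len();
    for (i, &idx) in indices.iter().enumerate() {
        cumsum += probs[idx];
        if cumsum >= top_p {
            cutoff = i + 1;
            break;
        }
    }
    let nucleus = &indices[..cutoff];

    // Multinomial draw from the renormalized nucleus.
    let total: f32 = nucleus.iter().map(|&i| probs[i]).sum();
    let mut r = rng.gen::<f32>() * total;
    for &idx in nucleus {
        r -= probs[idx];
        if r <= 0.0 {
            return idx;
        }
    }
    nucleus[nucleus.len() - 1]
}

With top_p = 1.0 the nucleus covers the whole distribution and this reduces to ordinary multinomial sampling, which is presumably why the examples below expose the option as Option<f64> and treat None as disabled.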
Diffstat (limited to 'candle-examples/examples')
-rw-r--r--  candle-examples/examples/bigcode/main.rs    16
-rw-r--r--  candle-examples/examples/falcon/main.rs     37
-rw-r--r--  candle-examples/examples/llama/main.rs       6
-rw-r--r--  candle-examples/examples/llama2-c/main.rs    7
-rw-r--r--  candle-examples/examples/quantized/main.rs   6
5 files changed, 54 insertions(+), 18 deletions(-)
diff --git a/candle-examples/examples/bigcode/main.rs b/candle-examples/examples/bigcode/main.rs
index 3540f75d..5f17109e 100644
--- a/candle-examples/examples/bigcode/main.rs
+++ b/candle-examples/examples/bigcode/main.rs
@@ -28,9 +28,10 @@ impl TextGeneration {
tokenizer: Tokenizer,
seed: u64,
temp: Option<f64>,
+ top_p: Option<f64>,
device: &Device,
) -> Self {
- let logits_processor = LogitsProcessor::new(seed, temp);
+ let logits_processor = LogitsProcessor::new(seed, temp, top_p);
Self {
model,
tokenizer,
@@ -94,6 +95,10 @@ struct Args {
#[arg(long)]
temperature: Option<f64>,
+ /// Nucleus sampling probability cutoff.
+ #[arg(long)]
+ top_p: Option<f64>,
+
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
@@ -149,7 +154,14 @@ fn main() -> Result<()> {
let model = GPTBigCode::load(vb, config)?;
println!("loaded the model in {:?}", start.elapsed());
- let mut pipeline = TextGeneration::new(model, tokenizer, args.seed, args.temperature, &device);
+ let mut pipeline = TextGeneration::new(
+ model,
+ tokenizer,
+ args.seed,
+ args.temperature,
+ args.top_p,
+ &device,
+ );
pipeline.run(&args.prompt, args.sample_len)?;
Ok(())
}
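All five examples thread the new option through to LogitsProcessor::new, which now takes three arguments. The constructor itself is defined outside this diff; judging from the call sites, its signature is presumably along the lines of:

pub fn new(seed: u64, temperature: Option<f64>, top_p: Option<f64>) -> Self

Assuming clap's usual kebab-case flag derivation for the top_p field, the new option would be exercised like this (prompt and values illustrative):

cargo run --example bigcode --release -- --prompt "fn fib(" --temperature 0.8 --top-p 0.9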
diff --git a/candle-examples/examples/falcon/main.rs b/candle-examples/examples/falcon/main.rs
index c45fe545..b0973d64 100644
--- a/candle-examples/examples/falcon/main.rs
+++ b/candle-examples/examples/falcon/main.rs
@@ -25,17 +25,25 @@ struct TextGeneration {
repeat_last_n: usize,
}
+struct GenerationOptions {
+ temp: Option<f64>,
+ top_p: Option<f64>,
+ repeat_penalty: f32,
+ repeat_last_n: usize,
+}
+
impl TextGeneration {
fn new(
model: Falcon,
tokenizer: Tokenizer,
+ generation_options: GenerationOptions,
seed: u64,
- temp: Option<f64>,
device: &Device,
- repeat_penalty: f32,
- repeat_last_n: usize,
) -> Self {
- let logits_processor = LogitsProcessor::new(seed, temp);
+ let logits_processor =
+ LogitsProcessor::new(seed, generation_options.temp, generation_options.top_p);
+ let repeat_penalty = generation_options.repeat_penalty;
+ let repeat_last_n = generation_options.repeat_last_n;
Self {
model,
tokenizer,
@@ -118,6 +126,10 @@ struct Args {
#[arg(long)]
temperature: Option<f64>,
+ /// Nucleus sampling probability cutoff.
+ #[arg(long)]
+ top_p: Option<f64>,
+
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
@@ -185,15 +197,14 @@ fn main() -> Result<()> {
let model = Falcon::load(vb, config)?;
println!("loaded the model in {:?}", start.elapsed());
- let mut pipeline = TextGeneration::new(
- model,
- tokenizer,
- args.seed,
- args.temperature,
- &device,
- args.repeat_penalty,
- args.repeat_last_n,
- );
+ let generation_options = GenerationOptions {
+ temp: args.temperature,
+ top_p: args.top_p,
+ repeat_penalty: args.repeat_penalty,
+ repeat_last_n: args.repeat_last_n,
+ };
+ let mut pipeline =
+ TextGeneration::new(model, tokenizer, generation_options, args.seed, &device);
pipeline.run(&args.prompt, args.sample_len)?;
Ok(())
}
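A note on the refactor above: bundling temp, top_p, repeat_penalty, and repeat_last_n into a GenerationOptions struct keeps TextGeneration::new at five parameters. Adding top_p directly would have pushed it to eight, past clippy's default too_many_arguments threshold of seven, which is presumably the clippy warning the commit message says it fixes.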
diff --git a/candle-examples/examples/llama/main.rs b/candle-examples/examples/llama/main.rs
index db3d216c..b2d7d938 100644
--- a/candle-examples/examples/llama/main.rs
+++ b/candle-examples/examples/llama/main.rs
@@ -42,6 +42,10 @@ struct Args {
#[arg(long)]
temperature: Option<f64>,
+ /// Nucleus sampling probability cutoff.
+ #[arg(long)]
+ top_p: Option<f64>,
+
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
@@ -193,7 +197,7 @@ fn main() -> Result<()> {
println!("starting the inference loop");
print!("{prompt}");
- let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature);
+ let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature, args.top_p);
let start_gen = std::time::Instant::now();
let mut index_pos = 0;
let mut token_generated = 0;
diff --git a/candle-examples/examples/llama2-c/main.rs b/candle-examples/examples/llama2-c/main.rs
index e0ade322..e752a494 100644
--- a/candle-examples/examples/llama2-c/main.rs
+++ b/candle-examples/examples/llama2-c/main.rs
@@ -27,6 +27,10 @@ struct InferenceCmd {
#[arg(long)]
temperature: Option<f64>,
+ /// Nucleus sampling probability cutoff.
+ #[arg(long)]
+ top_p: Option<f64>,
+
#[arg(long, default_value = "")]
prompt: String,
@@ -133,6 +137,7 @@ fn main() -> anyhow::Result<()> {
None => {
let cmd = InferenceCmd {
temperature: None,
+ top_p: None,
prompt: "".to_string(),
config: None,
model_id: "karpathy/tinyllamas".to_string(),
@@ -256,7 +261,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
let model = Llama::load(vb, &cache, config)?;
println!("starting the inference loop");
- let mut logits_processor = LogitsProcessor::new(299792458, args.temperature);
+ let mut logits_processor = LogitsProcessor::new(299792458, args.temperature, args.top_p);
let mut index_pos = 0;
print!("{}", args.prompt);
diff --git a/candle-examples/examples/quantized/main.rs b/candle-examples/examples/quantized/main.rs
index c8179d33..a80ad420 100644
--- a/candle-examples/examples/quantized/main.rs
+++ b/candle-examples/examples/quantized/main.rs
@@ -71,6 +71,10 @@ struct Args {
#[arg(long, default_value_t = 0.8)]
temperature: f64,
+ /// Nucleus sampling probability cutoff.
+ #[arg(long)]
+ top_p: Option<f64>,
+
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
@@ -310,7 +314,7 @@ fn main() -> anyhow::Result<()> {
prompt_tokens
};
let mut all_tokens = vec![];
- let mut logits_processor = LogitsProcessor::new(args.seed, temperature);
+ let mut logits_processor = LogitsProcessor::new(args.seed, temperature, args.top_p);
let start_prompt_processing = std::time::Instant::now();
let mut next_token = {