summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-examples/examples/llama/main.rs24
1 files changed, 21 insertions, 3 deletions
diff --git a/candle-examples/examples/llama/main.rs b/candle-examples/examples/llama/main.rs
index 72656295..fa7ce81b 100644
--- a/candle-examples/examples/llama/main.rs
+++ b/candle-examples/examples/llama/main.rs
@@ -17,7 +17,7 @@ use clap::{Parser, ValueEnum};
use candle::{DType, Tensor};
use candle_nn::VarBuilder;
-use candle_transformers::generation::LogitsProcessor;
+use candle_transformers::generation::{LogitsProcessor, Sampling};
use hf_hub::{api::sync::Api, Repo, RepoType};
use std::io::Write;
@@ -54,12 +54,16 @@ struct Args {
#[arg(long)]
top_p: Option<f64>,
+ /// Only sample among the top K samples.
+ #[arg(long)]
+ top_k: Option<usize>,
+
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
/// The length of the sample to generate (in tokens).
- #[arg(long, default_value_t = 10000)]
+ #[arg(short = 'n', long, default_value_t = 10000)]
sample_len: usize,
/// Disable the key-value cache.
@@ -166,7 +170,21 @@ fn main() -> Result<()> {
println!("starting the inference loop");
print!("{prompt}");
- let mut logits_processor = LogitsProcessor::new(args.seed, Some(args.temperature), args.top_p);
+ let mut logits_processor = {
+ let temperature = args.temperature;
+ let sampling = if temperature <= 0. {
+ Sampling::ArgMax
+ } else {
+ match (args.top_k, args.top_p) {
+ (None, None) => Sampling::All { temperature },
+ (Some(k), None) => Sampling::TopK { k, temperature },
+ (None, Some(p)) => Sampling::TopP { p, temperature },
+ (Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
+ }
+ };
+ LogitsProcessor::from_sampling(args.seed, sampling)
+ };
+
let mut start_gen = std::time::Instant::now();
let mut index_pos = 0;
let mut token_generated = 0;