summaryrefslogtreecommitdiff
path: root/candle-examples/examples/llama2-c/main.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-examples/examples/llama2-c/main.rs')
-rw-r--r--candle-examples/examples/llama2-c/main.rs7
1 file changed, 6 insertions(+), 1 deletion(-)
diff --git a/candle-examples/examples/llama2-c/main.rs b/candle-examples/examples/llama2-c/main.rs
index e0ade322..e752a494 100644
--- a/candle-examples/examples/llama2-c/main.rs
+++ b/candle-examples/examples/llama2-c/main.rs
@@ -27,6 +27,10 @@ struct InferenceCmd {
#[arg(long)]
temperature: Option<f64>,
+ /// Nucleus sampling probability cutoff.
+ #[arg(long)]
+ top_p: Option<f64>,
+
#[arg(long, default_value = "")]
prompt: String,
@@ -133,6 +137,7 @@ fn main() -> anyhow::Result<()> {
None => {
let cmd = InferenceCmd {
temperature: None,
+ top_p: None,
prompt: "".to_string(),
config: None,
model_id: "karpathy/tinyllamas".to_string(),
@@ -256,7 +261,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
let model = Llama::load(vb, &cache, config)?;
println!("starting the inference loop");
- let mut logits_processor = LogitsProcessor::new(299792458, args.temperature);
+ let mut logits_processor = LogitsProcessor::new(299792458, args.temperature, args.top_p);
let mut index_pos = 0;
print!("{}", args.prompt);