Diffstat (limited to 'candle-examples/examples/llama/main.rs')
-rw-r--r--  candle-examples/examples/llama/main.rs | 9 ++++++---
1 file changed, 6 insertions(+), 3 deletions(-)
diff --git a/candle-examples/examples/llama/main.rs b/candle-examples/examples/llama/main.rs
index 6f8766d4..b2d7d938 100644
--- a/candle-examples/examples/llama/main.rs
+++ b/candle-examples/examples/llama/main.rs
@@ -21,11 +21,10 @@ use candle_transformers::generation::LogitsProcessor;
 use hf_hub::{api::sync::Api, Repo, RepoType};
 use std::io::Write;
 
-mod model;
+use candle_transformers::models::llama as model;
 use model::{Config, Llama, LlamaConfig};
 
 const EOS_TOKEN: &str = "</s>";
-const MAX_SEQ_LEN: usize = 4096;
 const DEFAULT_PROMPT: &str = "My favorite theorem is ";
 
 #[derive(Parser, Debug)]
@@ -43,6 +42,10 @@ struct Args {
     #[arg(long)]
     temperature: Option<f64>,
 
+    /// Nucleus sampling probability cutoff.
+    #[arg(long)]
+    top_p: Option<f64>,
+
     /// The seed to use when generating random samples.
     #[arg(long, default_value_t = 299792458)]
     seed: u64,
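
The new flag enables nucleus (top-p) sampling: rather than drawing from the full token distribution, sampling is restricted to the smallest set of highest-probability tokens whose cumulative probability exceeds the cutoff. A minimal sketch of the idea over a plain probability vector, using a hypothetical top_p_filter helper (illustrative only, not the candle-transformers implementation):

// Keep the smallest high-probability prefix whose cumulative mass
// exceeds `top_p`, then renormalize; a sketch, not candle's code.
fn top_p_filter(probs: &[f32], top_p: f32) -> Vec<(usize, f32)> {
    // Sort token indices by descending probability.
    let mut indexed: Vec<(usize, f32)> = probs.iter().copied().enumerate().collect();
    indexed.sort_by(|a, b| b.1.total_cmp(&a.1));
    // Accumulate until the cutoff is crossed.
    let mut kept = Vec::new();
    let mut cumulative = 0.0;
    for (idx, p) in indexed {
        kept.push((idx, p));
        cumulative += p;
        if cumulative > top_p {
            break;
        }
    }
    // Renormalize the surviving probabilities before sampling.
    let total: f32 = kept.iter().map(|&(_, p)| p).sum();
    for entry in &mut kept {
        entry.1 /= total;
    }
    kept
}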
@@ -194,7 +197,7 @@ fn main() -> Result<()> {
     println!("starting the inference loop");
     print!("{prompt}");
 
-    let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature);
+    let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature, args.top_p);
    let start_gen = std::time::Instant::now();
     let mut index_pos = 0;
     let mut token_generated = 0;
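
With this change, LogitsProcessor::new takes the nucleus sampling cutoff as a third argument alongside the seed and temperature; passing None leaves the previous behavior intact (sampling from the full temperature-scaled distribution). Because the field is declared with clap's #[arg(long)], it surfaces as a --top-p command-line flag; a hypothetical invocation (model selection and prompt flags elided) might look like:

cargo run --example llama --release -- --temperature 0.8 --top-p 0.95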