Diffstat (limited to 'candle-examples/examples/llama/main.rs')
-rw-r--r-- | candle-examples/examples/llama/main.rs | 9
1 file changed, 6 insertions, 3 deletions
diff --git a/candle-examples/examples/llama/main.rs b/candle-examples/examples/llama/main.rs
index 6f8766d4..b2d7d938 100644
--- a/candle-examples/examples/llama/main.rs
+++ b/candle-examples/examples/llama/main.rs
@@ -21,11 +21,10 @@ use candle_transformers::generation::LogitsProcessor;
 use hf_hub::{api::sync::Api, Repo, RepoType};
 use std::io::Write;
 
-mod model;
+use candle_transformers::models::llama as model;
 use model::{Config, Llama, LlamaConfig};
 
 const EOS_TOKEN: &str = "</s>";
-const MAX_SEQ_LEN: usize = 4096;
 const DEFAULT_PROMPT: &str = "My favorite theorem is ";
 
 #[derive(Parser, Debug)]
@@ -43,6 +42,10 @@ struct Args {
     #[arg(long)]
     temperature: Option<f64>,
 
+    /// Nucleus sampling probability cutoff.
+    #[arg(long)]
+    top_p: Option<f64>,
+
     /// The seed to use when generating random samples.
     #[arg(long, default_value_t = 299792458)]
     seed: u64,
@@ -194,7 +197,7 @@ fn main() -> Result<()> {
 
     println!("starting the inference loop");
     print!("{prompt}");
-    let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature);
+    let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature, args.top_p);
     let start_gen = std::time::Instant::now();
     let mut index_pos = 0;
     let mut token_generated = 0;
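
The new top_p argument feeds a nucleus-sampling cutoff into LogitsProcessor::new alongside the seed and temperature. The sketch below illustrates what such a cutoff does under the usual definition (keep the smallest set of most-probable tokens whose cumulative probability reaches top_p, then renormalize); it is a minimal, self-contained illustration of the concept, not candle's LogitsProcessor implementation, and the helper name top_p_filter is made up for this example.

// Sketch of nucleus (top-p) filtering over an already-normalized distribution.
fn top_p_filter(probs: &[f32], top_p: f32) -> Vec<f32> {
    // Sort token indices by probability, highest first.
    let mut indices: Vec<usize> = (0..probs.len()).collect();
    indices.sort_by(|&a, &b| probs[b].partial_cmp(&probs[a]).unwrap());

    // Keep tokens until the cumulative probability passes the cutoff.
    let mut cumulative = 0.0;
    let mut kept = vec![false; probs.len()];
    for &i in &indices {
        kept[i] = true;
        cumulative += probs[i];
        if cumulative >= top_p {
            break;
        }
    }

    // Zero out the tail and renormalize the kept mass so it sums to one.
    let kept_mass: f32 = probs
        .iter()
        .zip(&kept)
        .filter(|(_, &k)| k)
        .map(|(p, _)| p)
        .sum();
    probs
        .iter()
        .zip(&kept)
        .map(|(p, &k)| if k { p / kept_mass } else { 0.0 })
        .collect()
}

fn main() {
    // Toy distribution over five tokens; with top_p = 0.8 only the two most
    // likely tokens survive (0.5 + 0.3 >= 0.8) and are renormalized.
    let probs = [0.5, 0.3, 0.1, 0.07, 0.03];
    println!("{:?}", top_p_filter(&probs, 0.8));
}

Because top_p is an Option<f64> in the diff, omitting the flag presumably leaves sampling constrained only by the temperature. With clap's default kebab-case conversion the field should surface as --top-p, so an invocation along the lines of cargo run --example llama --release -- --temperature 0.8 --top-p 0.95 would exercise it (assumed command line, not taken from the source).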