diff options
Diffstat (limited to 'candle-examples/examples/bigcode/main.rs')
-rw-r--r-- | candle-examples/examples/bigcode/main.rs | 19 |
1 files changed, 15 insertions, 4 deletions
diff --git a/candle-examples/examples/bigcode/main.rs b/candle-examples/examples/bigcode/main.rs index 652cd47f..5f17109e 100644 --- a/candle-examples/examples/bigcode/main.rs +++ b/candle-examples/examples/bigcode/main.rs @@ -7,8 +7,7 @@ extern crate accelerate_src; use anyhow::{Error as E, Result}; use clap::Parser; -mod model; -use model::{Config, GPTBigCode}; +use candle_transformers::models::bigcode::{Config, GPTBigCode}; use candle::{DType, Device, Tensor}; use candle_nn::VarBuilder; @@ -29,9 +28,10 @@ impl TextGeneration { tokenizer: Tokenizer, seed: u64, temp: Option<f64>, + top_p: Option<f64>, device: &Device, ) -> Self { - let logits_processor = LogitsProcessor::new(seed, temp); + let logits_processor = LogitsProcessor::new(seed, temp, top_p); Self { model, tokenizer, @@ -95,6 +95,10 @@ struct Args { #[arg(long)] temperature: Option<f64>, + /// Nucleus sampling probability cutoff. + #[arg(long)] + top_p: Option<f64>, + /// The seed to use when generating random samples. #[arg(long, default_value_t = 299792458)] seed: u64, @@ -150,7 +154,14 @@ fn main() -> Result<()> { let model = GPTBigCode::load(vb, config)?; println!("loaded the model in {:?}", start.elapsed()); - let mut pipeline = TextGeneration::new(model, tokenizer, args.seed, args.temperature, &device); + let mut pipeline = TextGeneration::new( + model, + tokenizer, + args.seed, + args.temperature, + args.top_p, + &device, + ); pipeline.run(&args.prompt, args.sample_len)?; Ok(()) } |