summary | refs | log | tree | commit | diff
path: root/candle-examples/examples/bigcode/main.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-examples/examples/bigcode/main.rs')
-rw-r--r-- candle-examples/examples/bigcode/main.rs | 19
1 files changed, 15 insertions, 4 deletions
diff --git a/candle-examples/examples/bigcode/main.rs b/candle-examples/examples/bigcode/main.rs
index 652cd47f..5f17109e 100644
--- a/candle-examples/examples/bigcode/main.rs
+++ b/candle-examples/examples/bigcode/main.rs
@@ -7,8 +7,7 @@ extern crate accelerate_src;
use anyhow::{Error as E, Result};
use clap::Parser;
-mod model;
-use model::{Config, GPTBigCode};
+use candle_transformers::models::bigcode::{Config, GPTBigCode};
use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
@@ -29,9 +28,10 @@ impl TextGeneration {
tokenizer: Tokenizer,
seed: u64,
temp: Option<f64>,
+ top_p: Option<f64>,
device: &Device,
) -> Self {
- let logits_processor = LogitsProcessor::new(seed, temp);
+ let logits_processor = LogitsProcessor::new(seed, temp, top_p);
Self {
model,
tokenizer,
@@ -95,6 +95,10 @@ struct Args {
#[arg(long)]
temperature: Option<f64>,
+ /// Nucleus sampling probability cutoff.
+ #[arg(long)]
+ top_p: Option<f64>,
+
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
@@ -150,7 +154,14 @@ fn main() -> Result<()> {
let model = GPTBigCode::load(vb, config)?;
println!("loaded the model in {:?}", start.elapsed());
- let mut pipeline = TextGeneration::new(model, tokenizer, args.seed, args.temperature, &device);
+ let mut pipeline = TextGeneration::new(
+ model,
+ tokenizer,
+ args.seed,
+ args.temperature,
+ args.top_p,
+ &device,
+ );
pipeline.run(&args.prompt, args.sample_len)?;
Ok(())
}