summary | refs | log | tree | commit | diff
path: root/candle-examples/examples/phi/main.rs
diff options
context:
space:
mode:
author: Laurent Mazare <laurent.mazare@gmail.com> 2023-09-25 20:53:30 +0100
committer: GitHub <noreply@github.com> 2023-09-25 20:53:30 +0100
commit: c78a29432329e2947e0c417d8dc1e5ed6aa25bad (patch)
tree: 34b7fb2a28a096ce6323a81a99d704a61dd42ca5 /candle-examples/examples/phi/main.rs
parent: a36d88325492c5e721b2045b2988cc11ef44fc0e (diff)
download: candle-c78a29432329e2947e0c417d8dc1e5ed6aa25bad.tar.gz
candle-c78a29432329e2947e0c417d8dc1e5ed6aa25bad.tar.bz2
candle-c78a29432329e2947e0c417d8dc1e5ed6aa25bad.zip
Add some repeat penalty to the phi example. (#961)
Diffstat (limited to 'candle-examples/examples/phi/main.rs')
-rw-r--r-- candle-examples/examples/phi/main.rs 41
1 files changed, 40 insertions, 1 deletions
diff --git a/candle-examples/examples/phi/main.rs b/candle-examples/examples/phi/main.rs
index 5cdbb4b9..fe365e18 100644
--- a/candle-examples/examples/phi/main.rs
+++ b/candle-examples/examples/phi/main.rs
@@ -26,15 +26,20 @@ struct TextGeneration {
device: Device,
tokenizer: Tokenizer,
logits_processor: LogitsProcessor,
+ repeat_penalty: f32,
+ repeat_last_n: usize,
}
impl TextGeneration {
+ #[allow(clippy::too_many_arguments)]
fn new(
model: Model,
tokenizer: Tokenizer,
seed: u64,
temp: Option<f64>,
top_p: Option<f64>,
+ repeat_penalty: f32,
+ repeat_last_n: usize,
device: &Device,
) -> Self {
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
@@ -42,6 +47,8 @@ impl TextGeneration {
model,
tokenizer,
logits_processor,
+ repeat_penalty,
+ repeat_last_n,
device: device.clone(),
}
}
@@ -69,6 +76,16 @@ impl TextGeneration {
Model::Quantized(m) => m.forward(&input)?,
};
let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
+ let logits = if self.repeat_penalty == 1. {
+ logits
+ } else {
+ let start_at = tokens.len().saturating_sub(self.repeat_last_n);
+ candle_transformers::utils::apply_repeat_penalty(
+ &logits,
+ self.repeat_penalty,
+ &tokens[start_at..],
+ )?
+ };
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
@@ -127,6 +144,14 @@ struct Args {
#[arg(long)]
quantized: bool,
+
+ /// Penalty to be applied for repeating tokens, 1. means no penalty.
+ #[arg(long, default_value_t = 1.1)]
+ repeat_penalty: f32,
+
+ /// The context size to consider for the repeat penalty.
+ #[arg(long, default_value_t = 64)]
+ repeat_last_n: usize,
}
fn main() -> Result<()> {
@@ -134,7 +159,6 @@ fn main() -> Result<()> {
use tracing_subscriber::prelude::*;
let args = Args::parse();
-
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
@@ -142,6 +166,19 @@ fn main() -> Result<()> {
} else {
None
};
+ println!(
+ "avx: {}, neon: {}, simd128: {}, f16c: {}",
+ candle::utils::with_avx(),
+ candle::utils::with_neon(),
+ candle::utils::with_simd128(),
+ candle::utils::with_f16c()
+ );
+ println!(
+ "temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
+ args.temperature.unwrap_or(0.),
+ args.repeat_penalty,
+ args.repeat_last_n
+ );
let start = std::time::Instant::now();
let api = Api::new()?;
@@ -185,6 +222,8 @@ fn main() -> Result<()> {
args.seed,
args.temperature,
args.top_p,
+ args.repeat_penalty,
+ args.repeat_last_n,
&device,
);
pipeline.run(&args.prompt, args.sample_len)?;