From 6e485f2deb65bf21d21c85b4913149e7d2c65c6b Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Sun, 27 Aug 2023 10:48:45 +0100
Subject: Add some optional repeat penalty. (#623)

* Add some optional repeat penalty.

* Add the missing files.
---
 candle-examples/examples/llama/main.rs | 18 ++++++++++++++++++
 1 file changed, 18 insertions(+)

diff --git a/candle-examples/examples/llama/main.rs b/candle-examples/examples/llama/main.rs
index 2f4a4cd8..6f8766d4 100644
--- a/candle-examples/examples/llama/main.rs
+++ b/candle-examples/examples/llama/main.rs
@@ -83,6 +83,14 @@ struct Args {
     /// (same structure as huggingface online)
     #[arg(long)]
     local_weights: Option<String>,
+
+    /// Penalty to be applied for repeating tokens, 1. means no penalty.
+    #[arg(long, default_value_t = 1.0)]
+    repeat_penalty: f32,
+
+    /// The context size to consider for the repeat penalty.
+    #[arg(long, default_value_t = 64)]
+    repeat_last_n: usize,
 }
 
 fn main() -> Result<()> {
@@ -200,6 +208,16 @@ fn main() -> Result<()> {
         let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
         let logits = llama.forward(&input, index_pos)?;
         let logits = logits.squeeze(0)?;
+        let logits = if args.repeat_penalty == 1. {
+            logits
+        } else {
+            let start_at = tokens.len().saturating_sub(args.repeat_last_n);
+            candle_transformers::utils::apply_repeat_penalty(
+                &logits,
+                args.repeat_penalty,
+                &tokens[start_at..],
+            )?
+        };
         index_pos += ctxt.len();
 
         let next_token = logits_processor.sample(&logits)?;
--
cgit v1.2.3
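
For context, the new repeat_penalty / repeat_last_n options rescale the logits of any token that already appeared in the last repeat_last_n generated tokens before the next token is sampled, which discourages the model from looping. The standalone Rust sketch below is a minimal illustration of that idea, assuming candle_transformers::utils::apply_repeat_penalty follows the usual CTRL-style rule; the function apply_repeat_penalty_sketch and the sample values are hypothetical and are not the library implementation.

    use std::collections::HashSet;

    // Sketch of a CTRL-style repeat penalty: each token seen in the recent
    // context gets its logit shrunk if positive, or pushed further down if
    // negative, so the sampler is less likely to pick it again.
    fn apply_repeat_penalty_sketch(logits: &mut [f32], penalty: f32, context: &[u32]) {
        // Deduplicate so a token is penalised once, however often it occurred.
        let recent: HashSet<u32> = context.iter().copied().collect();
        for &token in &recent {
            if let Some(logit) = logits.get_mut(token as usize) {
                if *logit >= 0.0 {
                    *logit /= penalty;
                } else {
                    *logit *= penalty;
                }
            }
        }
    }

    fn main() {
        // Pretend tokens 0 and 3 were produced within the last repeat_last_n steps.
        let mut logits = vec![2.0_f32, -1.0, 0.5, 3.0];
        apply_repeat_penalty_sketch(&mut logits, 1.5, &[0, 3]);
        println!("{logits:?}"); // [1.3333334, -1.0, 0.5, 2.0]
    }

With the patch applied, the example would be invoked with something like the following (clap derives kebab-case flag names from the new fields):

    cargo run --example llama --release -- --repeat-penalty 1.1 --repeat-last-n 64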