Diffstat (limited to 'candle-examples/examples/llama')
-rw-r--r--  candle-examples/examples/llama/main.rs  18
1 file changed, 18 insertions, 0 deletions
diff --git a/candle-examples/examples/llama/main.rs b/candle-examples/examples/llama/main.rs
index 2f4a4cd8..6f8766d4 100644
--- a/candle-examples/examples/llama/main.rs
+++ b/candle-examples/examples/llama/main.rs
@@ -83,6 +83,14 @@ struct Args {
    /// (same structure as huggingface online)
    #[arg(long)]
    local_weights: Option<String>,
+
+    /// Penalty applied to repeated tokens; 1.0 means no penalty.
+    #[arg(long, default_value_t = 1.0)]
+    repeat_penalty: f32,
+
+    /// The context size to consider for the repeat penalty.
+    #[arg(long, default_value_t = 64)]
+    repeat_last_n: usize,
}

fn main() -> Result<()> {
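
Note on how these arguments surface on the command line: clap's derive macro maps the snake_case fields to kebab-case flags, so the example accepts --repeat-penalty and --repeat-last-n, falling back to the defaults above when the flags are omitted. Below is a trimmed-down, standalone sketch of just these two flags (PenaltyArgs is a hypothetical name; the real struct is the full Args shown in the hunk):

use clap::Parser;

// Illustrative sketch: only the two new flags, not the full Args struct.
#[derive(Parser, Debug)]
struct PenaltyArgs {
    /// Penalty applied to repeated tokens; 1.0 means no penalty.
    #[arg(long, default_value_t = 1.0)]
    repeat_penalty: f32,

    /// The context size to consider for the repeat penalty.
    #[arg(long, default_value_t = 64)]
    repeat_last_n: usize,
}

fn main() {
    // Parses e.g. `--repeat-penalty 1.1 --repeat-last-n 64`.
    let args = PenaltyArgs::parse();
    println!("{args:?}");
}
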
@@ -200,6 +208,16 @@ fn main() -> Result<()> {
        let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
        let logits = llama.forward(&input, index_pos)?;
        let logits = logits.squeeze(0)?;
+        let logits = if args.repeat_penalty == 1. {
+            logits
+        } else {
+            let start_at = tokens.len().saturating_sub(args.repeat_last_n);
+            candle_transformers::utils::apply_repeat_penalty(
+                &logits,
+                args.repeat_penalty,
+                &tokens[start_at..],
+            )?
+        };
        index_pos += ctxt.len();
        let next_token = logits_processor.sample(&logits)?;
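
The heavy lifting happens in candle_transformers::utils::apply_repeat_penalty, which rescales the logit of every token id seen in the penalty window before sampling. The sketch below is a minimal, dependency-free rendition of the same CTRL-style penalty on a plain slice of logits; it is illustrative, not the crate's actual code (apply_repeat_penalty_sketch is a hypothetical name), but it captures the core rule: with penalty > 1.0, positive logits are divided and negative logits multiplied, so a recently seen token always becomes less likely.

use std::collections::HashSet;

// Penalize every distinct token id in `recent`, mirroring the windowing in
// the hunk above: `recent` would be
// &tokens[tokens.len().saturating_sub(repeat_last_n)..].
fn apply_repeat_penalty_sketch(logits: &mut [f32], penalty: f32, recent: &[u32]) {
    let mut seen = HashSet::new();
    for &token_id in recent {
        // Each distinct token is penalized at most once, however often it
        // occurs in the window.
        if !seen.insert(token_id) {
            continue;
        }
        if let Some(logit) = logits.get_mut(token_id as usize) {
            // With penalty > 1.0 this always pushes the logit toward -inf:
            // positive scores shrink, negative scores grow more negative.
            if *logit >= 0.0 {
                *logit /= penalty;
            } else {
                *logit *= penalty;
            }
        }
    }
}

fn main() {
    let mut logits = vec![2.0, -1.0, 0.5, 3.0];
    // Token ids 0 and 3 were generated recently; penalize them.
    apply_repeat_penalty_sketch(&mut logits, 1.1, &[0, 3, 3]);
    println!("{logits:?}"); // tokens 0 and 3 now score lower; 1 and 2 untouched
}

This also explains the `== 1.` fast path in the hunk: a penalty of exactly 1.0 is the identity transform, so the patch skips the rescaling work entirely in that case.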