author     Laurent Mazare <laurent.mazare@gmail.com>    2023-09-01 21:38:58 +0200
committer  GitHub <noreply@github.com>                  2023-09-01 20:38:58 +0100
commit     2c1df6bba1a2017b8b4aec87a725b5b06b48cdab
tree       a08329097c25cb79c7413e21c440f700ecc7ed41 /candle-examples/examples/llama2-c/main.rs
parent     4d56cef58398a5f676ab1fe12d3ecc2e5c4edc66
Add a repeat penalty to the llama2-c command line example. (#713)

* Add a repeat penalty to the llama2-c command line example.
* Another fix attempt.
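
For background, a repeat penalty rescales the logits of tokens that already occur in the recent context so the sampler is less likely to emit them again. The sketch below illustrates the idea on a plain f32 logits slice; it is a simplified illustration only, not the candle_transformers helper that the diff calls into.

    use std::collections::HashSet;

    /// Shrink the logits of tokens seen in the recent context.
    /// A `penalty` of 1.0 leaves the logits untouched; values above 1.0 discourage repeats.
    fn penalize_repeats(logits: &mut [f32], penalty: f32, recent_tokens: &[u32]) {
        // Penalize each distinct token id once, regardless of how often it repeats.
        let seen: HashSet<u32> = recent_tokens.iter().copied().collect();
        for token_id in seen {
            if let Some(logit) = logits.get_mut(token_id as usize) {
                // Dividing a positive logit, or multiplying a negative one, by the
                // penalty both push the token's probability down.
                if *logit >= 0.0 {
                    *logit /= penalty;
                } else {
                    *logit *= penalty;
                }
            }
        }
    }

The two new command line flags map onto this directly: repeat_last_n bounds how far back the context window reaches, and repeat_penalty sets the strength of the rescaling.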
Diffstat (limited to 'candle-examples/examples/llama2-c/main.rs')
-rw-r--r--  candle-examples/examples/llama2-c/main.rs  18
1 file changed, 18 insertions(+), 0 deletions(-)
diff --git a/candle-examples/examples/llama2-c/main.rs b/candle-examples/examples/llama2-c/main.rs
index 418218b6..e0ade322 100644
--- a/candle-examples/examples/llama2-c/main.rs
+++ b/candle-examples/examples/llama2-c/main.rs
@@ -103,6 +103,14 @@ pub struct Args {
/// Tokenizer config file.
#[arg(long)]
tokenizer: Option<String>,
+
+ /// Penalty to be applied for repeating tokens, 1. means no penalty.
+ #[arg(long, default_value_t = 1.1)]
+ repeat_penalty: f32,
+
+ /// The context size to consider for the repeat penalty.
+ #[arg(long, default_value_t = 64)]
+ repeat_last_n: usize,
}
impl Args {
@@ -268,6 +276,16 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
let logits = model.forward(&input, index_pos)?;
let logits = logits.i((0, logits.dim(1)? - 1))?;
+ let logits = if common_args.repeat_penalty == 1. || tokens.is_empty() {
+ logits
+ } else {
+ let start_at = tokens.len().saturating_sub(common_args.repeat_last_n);
+ candle_transformers::utils::apply_repeat_penalty(
+ &logits,
+ common_args.repeat_penalty,
+ &tokens[start_at..],
+ )?
+ };
index_pos += ctxt.len();
let next_token = logits_processor.sample(&logits)?;
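
For reference, the helper used above can also be exercised on its own. The snippet below is a minimal, hypothetical usage sketch assuming candle-core and candle-transformers as dependencies, with made-up logits and token ids; the call signature matches the one in the diff.

    use candle_core::{Device, Result, Tensor};

    fn main() -> Result<()> {
        // Four made-up logits for a toy vocabulary of size 4.
        let logits = Tensor::new(&[2.0f32, -1.0, 0.5, 3.0], &Device::Cpu)?;
        // Token ids seen in the last `repeat_last_n` positions of the generation.
        let recent_tokens: Vec<u32> = vec![0, 3];
        // Same call as in the diff: penalty 1.1 applied over the recent window.
        let penalized =
            candle_transformers::utils::apply_repeat_penalty(&logits, 1.1, &recent_tokens)?;
        println!("{penalized}");
        Ok(())
    }

Note how the example in the diff skips the call entirely when the penalty is exactly 1.0 or no tokens have been generated yet, avoiding a needless tensor copy on the default no-penalty path.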