diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-09-01 21:38:58 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-09-01 20:38:58 +0100 |
commit | 2c1df6bba1a2017b8b4aec87a725b5b06b48cdab (patch) | |
tree | a08329097c25cb79c7413e21c440f700ecc7ed41 /candle-examples/examples/llama2-c/main.rs | |
parent | 4d56cef58398a5f676ab1fe12d3ecc2e5c4edc66 (diff) | |
download | candle-2c1df6bba1a2017b8b4aec87a725b5b06b48cdab.tar.gz candle-2c1df6bba1a2017b8b4aec87a725b5b06b48cdab.tar.bz2 candle-2c1df6bba1a2017b8b4aec87a725b5b06b48cdab.zip |
Add a repeat penalty to the llama2-c command line example. (#713)
* Add a repeat penalty to the llama2-c command line example.
* Another fix attempt.
Diffstat (limited to 'candle-examples/examples/llama2-c/main.rs')
-rw-r--r-- | candle-examples/examples/llama2-c/main.rs | 18 |
1 files changed, 18 insertions, 0 deletions
diff --git a/candle-examples/examples/llama2-c/main.rs b/candle-examples/examples/llama2-c/main.rs index 418218b6..e0ade322 100644 --- a/candle-examples/examples/llama2-c/main.rs +++ b/candle-examples/examples/llama2-c/main.rs @@ -103,6 +103,14 @@ pub struct Args { /// Tokenizer config file. #[arg(long)] tokenizer: Option<String>, + + /// Penalty to be applied for repeating tokens, 1. means no penalty. + #[arg(long, default_value_t = 1.1)] + repeat_penalty: f32, + + /// The context size to consider for the repeat penalty. + #[arg(long, default_value_t = 64)] + repeat_last_n: usize, } impl Args { @@ -268,6 +276,16 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> { let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?; let logits = model.forward(&input, index_pos)?; let logits = logits.i((0, logits.dim(1)? - 1))?; + let logits = if common_args.repeat_penalty == 1. || tokens.is_empty() { + logits + } else { + let start_at = tokens.len().saturating_sub(common_args.repeat_last_n); + candle_transformers::utils::apply_repeat_penalty( + &logits, + common_args.repeat_penalty, + &tokens[start_at..], + )? + }; index_pos += ctxt.len(); let next_token = logits_processor.sample(&logits)?; |