diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-08-27 10:48:45 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-08-27 10:48:45 +0100 |
commit | 6e485f2deb65bf21d21c85b4913149e7d2c65c6b (patch) | |
tree | ebd2d2265e4e321f6cbd13d0a1445a77401a9ab4 | |
parent | 5320aa6b7d339ff594d3886dd29634ea8cde6f17 (diff) | |
download | candle-6e485f2deb65bf21d21c85b4913149e7d2c65c6b.tar.gz candle-6e485f2deb65bf21d21c85b4913149e7d2c65c6b.tar.bz2 candle-6e485f2deb65bf21d21c85b4913149e7d2c65c6b.zip |
Add some optional repeat penalty. (#623)
* Add some optional repeat penalty.
* Add the missing files.
-rw-r--r-- | candle-examples/examples/llama/main.rs | 18 | ||||
-rw-r--r-- | candle-examples/examples/quantized/main.rs | 22 | ||||
-rw-r--r-- | candle-transformers/src/lib.rs | 1 | ||||
-rw-r--r-- | candle-transformers/src/utils.rs | 18 |
4 files changed, 42 insertions, 17 deletions
diff --git a/candle-examples/examples/llama/main.rs b/candle-examples/examples/llama/main.rs index 2f4a4cd8..6f8766d4 100644 --- a/candle-examples/examples/llama/main.rs +++ b/candle-examples/examples/llama/main.rs @@ -83,6 +83,14 @@ struct Args { /// (same structure as huggingface online) #[arg(long)] local_weights: Option<String>, + + /// Penalty to be applied for repeating tokens, 1. means no penalty. + #[arg(long, default_value_t = 1.0)] + repeat_penalty: f32, + + /// The context size to consider for the repeat penalty. + #[arg(long, default_value_t = 64)] + repeat_last_n: usize, } fn main() -> Result<()> { @@ -200,6 +208,16 @@ fn main() -> Result<()> { let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?; let logits = llama.forward(&input, index_pos)?; let logits = logits.squeeze(0)?; + let logits = if args.repeat_penalty == 1. { + logits + } else { + let start_at = tokens.len().saturating_sub(args.repeat_last_n); + candle_transformers::utils::apply_repeat_penalty( + &logits, + args.repeat_penalty, + &tokens[start_at..], + )? + }; index_pos += ctxt.len(); let next_token = logits_processor.sample(&logits)?; diff --git a/candle-examples/examples/quantized/main.rs b/candle-examples/examples/quantized/main.rs index 209a0f55..661cce5a 100644 --- a/candle-examples/examples/quantized/main.rs +++ b/candle-examples/examples/quantized/main.rs @@ -533,22 +533,6 @@ fn print_token(next_token: u32, tokenizer: &Tokenizer) { } } -fn apply_repeat_penalty(logits: &Tensor, penalty: f32, context: &[u32]) -> Result<Tensor> { - let mut logits = logits.to_vec1::<f32>()?; - let context: std::collections::HashSet<_> = context.iter().collect(); - for (token_id, logit) in logits.iter_mut().enumerate() { - if context.contains(&(token_id as u32)) { - if *logit >= 0. { - *logit /= penalty - } else { - *logit *= penalty - } - } - } - let logits_len = logits.len(); - Tensor::from_vec(logits, logits_len, &Device::Cpu) -} - fn format_size(size_in_bytes: usize) -> String { if size_in_bytes < 1_000 { format!("{}B", size_in_bytes) @@ -670,7 +654,11 @@ fn main() -> anyhow::Result<()> { logits } else { let start_at = all_tokens.len().saturating_sub(args.repeat_last_n); - apply_repeat_penalty(&logits, args.repeat_penalty, &all_tokens[start_at..])? + candle_transformers::utils::apply_repeat_penalty( + &logits, + args.repeat_penalty, + &all_tokens[start_at..], + )? }; next_token = logits_processor.sample(&logits)?; all_tokens.push(next_token); diff --git a/candle-transformers/src/lib.rs b/candle-transformers/src/lib.rs index 86cb904e..a8890dc8 100644 --- a/candle-transformers/src/lib.rs +++ b/candle-transformers/src/lib.rs @@ -1,3 +1,4 @@ pub mod generation; pub mod models; pub mod pipelines; +pub mod utils; diff --git a/candle-transformers/src/utils.rs b/candle-transformers/src/utils.rs new file mode 100644 index 00000000..50d3b707 --- /dev/null +++ b/candle-transformers/src/utils.rs @@ -0,0 +1,18 @@ +use candle::{Result, Tensor}; + +pub fn apply_repeat_penalty(logits: &Tensor, penalty: f32, context: &[u32]) -> Result<Tensor> { + let device = logits.device(); + let mut logits = logits.to_vec1::<f32>()?; + let context: std::collections::HashSet<_> = context.iter().collect(); + for (token_id, logit) in logits.iter_mut().enumerate() { + if context.contains(&(token_id as u32)) { + if *logit >= 0. { + *logit /= penalty + } else { + *logit *= penalty + } + } + } + let logits_len = logits.len(); + Tensor::from_vec(logits, logits_len, device) +} |