summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-08-27 10:48:45 +0100
committerGitHub <noreply@github.com>2023-08-27 10:48:45 +0100
commit6e485f2deb65bf21d21c85b4913149e7d2c65c6b (patch)
treeebd2d2265e4e321f6cbd13d0a1445a77401a9ab4
parent5320aa6b7d339ff594d3886dd29634ea8cde6f17 (diff)
downloadcandle-6e485f2deb65bf21d21c85b4913149e7d2c65c6b.tar.gz
candle-6e485f2deb65bf21d21c85b4913149e7d2c65c6b.tar.bz2
candle-6e485f2deb65bf21d21c85b4913149e7d2c65c6b.zip
Add some optional repeat penalty. (#623)
* Add some optional repeat penalty. * Add the missing files.
-rw-r--r--candle-examples/examples/llama/main.rs18
-rw-r--r--candle-examples/examples/quantized/main.rs22
-rw-r--r--candle-transformers/src/lib.rs1
-rw-r--r--candle-transformers/src/utils.rs18
4 files changed, 42 insertions, 17 deletions
diff --git a/candle-examples/examples/llama/main.rs b/candle-examples/examples/llama/main.rs
index 2f4a4cd8..6f8766d4 100644
--- a/candle-examples/examples/llama/main.rs
+++ b/candle-examples/examples/llama/main.rs
@@ -83,6 +83,14 @@ struct Args {
/// (same structure as huggingface online)
#[arg(long)]
local_weights: Option<String>,
+
+ /// Penalty to be applied for repeating tokens, 1. means no penalty.
+ #[arg(long, default_value_t = 1.0)]
+ repeat_penalty: f32,
+
+ /// The context size to consider for the repeat penalty.
+ #[arg(long, default_value_t = 64)]
+ repeat_last_n: usize,
}
fn main() -> Result<()> {
@@ -200,6 +208,16 @@ fn main() -> Result<()> {
let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
let logits = llama.forward(&input, index_pos)?;
let logits = logits.squeeze(0)?;
+ let logits = if args.repeat_penalty == 1. {
+ logits
+ } else {
+ let start_at = tokens.len().saturating_sub(args.repeat_last_n);
+ candle_transformers::utils::apply_repeat_penalty(
+ &logits,
+ args.repeat_penalty,
+ &tokens[start_at..],
+ )?
+ };
index_pos += ctxt.len();
let next_token = logits_processor.sample(&logits)?;
diff --git a/candle-examples/examples/quantized/main.rs b/candle-examples/examples/quantized/main.rs
index 209a0f55..661cce5a 100644
--- a/candle-examples/examples/quantized/main.rs
+++ b/candle-examples/examples/quantized/main.rs
@@ -533,22 +533,6 @@ fn print_token(next_token: u32, tokenizer: &Tokenizer) {
}
}
-fn apply_repeat_penalty(logits: &Tensor, penalty: f32, context: &[u32]) -> Result<Tensor> {
- let mut logits = logits.to_vec1::<f32>()?;
- let context: std::collections::HashSet<_> = context.iter().collect();
- for (token_id, logit) in logits.iter_mut().enumerate() {
- if context.contains(&(token_id as u32)) {
- if *logit >= 0. {
- *logit /= penalty
- } else {
- *logit *= penalty
- }
- }
- }
- let logits_len = logits.len();
- Tensor::from_vec(logits, logits_len, &Device::Cpu)
-}
-
fn format_size(size_in_bytes: usize) -> String {
if size_in_bytes < 1_000 {
format!("{}B", size_in_bytes)
@@ -670,7 +654,11 @@ fn main() -> anyhow::Result<()> {
logits
} else {
let start_at = all_tokens.len().saturating_sub(args.repeat_last_n);
- apply_repeat_penalty(&logits, args.repeat_penalty, &all_tokens[start_at..])?
+ candle_transformers::utils::apply_repeat_penalty(
+ &logits,
+ args.repeat_penalty,
+ &all_tokens[start_at..],
+ )?
};
next_token = logits_processor.sample(&logits)?;
all_tokens.push(next_token);
diff --git a/candle-transformers/src/lib.rs b/candle-transformers/src/lib.rs
index 86cb904e..a8890dc8 100644
--- a/candle-transformers/src/lib.rs
+++ b/candle-transformers/src/lib.rs
@@ -1,3 +1,4 @@
pub mod generation;
pub mod models;
pub mod pipelines;
+pub mod utils;
diff --git a/candle-transformers/src/utils.rs b/candle-transformers/src/utils.rs
new file mode 100644
index 00000000..50d3b707
--- /dev/null
+++ b/candle-transformers/src/utils.rs
@@ -0,0 +1,18 @@
+use candle::{Result, Tensor};
+
+pub fn apply_repeat_penalty(logits: &Tensor, penalty: f32, context: &[u32]) -> Result<Tensor> {
+ let device = logits.device();
+ let mut logits = logits.to_vec1::<f32>()?;
+ let context: std::collections::HashSet<_> = context.iter().collect();
+ for (token_id, logit) in logits.iter_mut().enumerate() {
+ if context.contains(&(token_id as u32)) {
+ if *logit >= 0. {
+ *logit /= penalty
+ } else {
+ *logit *= penalty
+ }
+ }
+ }
+ let logits_len = logits.len();
+ Tensor::from_vec(logits, logits_len, device)
+}