diff options
Diffstat (limited to 'candle-examples/examples/quantized/main.rs')
-rw-r--r-- | candle-examples/examples/quantized/main.rs | 22 |
1 files changed, 5 insertions, 17 deletions
diff --git a/candle-examples/examples/quantized/main.rs b/candle-examples/examples/quantized/main.rs index 209a0f55..661cce5a 100644 --- a/candle-examples/examples/quantized/main.rs +++ b/candle-examples/examples/quantized/main.rs @@ -533,22 +533,6 @@ fn print_token(next_token: u32, tokenizer: &Tokenizer) { } } -fn apply_repeat_penalty(logits: &Tensor, penalty: f32, context: &[u32]) -> Result<Tensor> { - let mut logits = logits.to_vec1::<f32>()?; - let context: std::collections::HashSet<_> = context.iter().collect(); - for (token_id, logit) in logits.iter_mut().enumerate() { - if context.contains(&(token_id as u32)) { - if *logit >= 0. { - *logit /= penalty - } else { - *logit *= penalty - } - } - } - let logits_len = logits.len(); - Tensor::from_vec(logits, logits_len, &Device::Cpu) -} - fn format_size(size_in_bytes: usize) -> String { if size_in_bytes < 1_000 { format!("{}B", size_in_bytes) @@ -670,7 +654,11 @@ fn main() -> anyhow::Result<()> { logits } else { let start_at = all_tokens.len().saturating_sub(args.repeat_last_n); - apply_repeat_penalty(&logits, args.repeat_penalty, &all_tokens[start_at..])? + candle_transformers::utils::apply_repeat_penalty( + &logits, + args.repeat_penalty, + &all_tokens[start_at..], + )? }; next_token = logits_processor.sample(&logits)?; all_tokens.push(next_token); |