use candle::{Result, Tensor}; pub fn apply_repeat_penalty(logits: &Tensor, penalty: f32, context: &[u32]) -> Result { let device = logits.device(); let mut logits = logits.to_vec1::()?; let context: std::collections::HashSet<_> = context.iter().collect(); for (token_id, logit) in logits.iter_mut().enumerate() { if context.contains(&(token_id as u32)) { if *logit >= 0. { *logit /= penalty } else { *logit *= penalty } } } let logits_len = logits.len(); Tensor::from_vec(logits, logits_len, device) }