diff options
Diffstat (limited to 'candle-transformers/src')
-rw-r--r-- | candle-transformers/src/lib.rs | 1 | ||||
-rw-r--r-- | candle-transformers/src/utils.rs | 18 |
2 files changed, 19 insertions, 0 deletions
diff --git a/candle-transformers/src/lib.rs b/candle-transformers/src/lib.rs index 86cb904e..a8890dc8 100644 --- a/candle-transformers/src/lib.rs +++ b/candle-transformers/src/lib.rs @@ -1,3 +1,4 @@ pub mod generation; pub mod models; pub mod pipelines; +pub mod utils; diff --git a/candle-transformers/src/utils.rs b/candle-transformers/src/utils.rs new file mode 100644 index 00000000..50d3b707 --- /dev/null +++ b/candle-transformers/src/utils.rs @@ -0,0 +1,18 @@ +use candle::{Result, Tensor}; + +pub fn apply_repeat_penalty(logits: &Tensor, penalty: f32, context: &[u32]) -> Result<Tensor> { + let device = logits.device(); + let mut logits = logits.to_vec1::<f32>()?; + let context: std::collections::HashSet<_> = context.iter().collect(); + for (token_id, logit) in logits.iter_mut().enumerate() { + if context.contains(&(token_id as u32)) { + if *logit >= 0. { + *logit /= penalty + } else { + *logit *= penalty + } + } + } + let logits_len = logits.len(); + Tensor::from_vec(logits, logits_len, device) +} |