summaryrefslogtreecommitdiff
path: root/candle-examples/examples/quantized/main.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-examples/examples/quantized/main.rs')
-rw-r--r--candle-examples/examples/quantized/main.rs22
1 files changed, 5 insertions, 17 deletions
diff --git a/candle-examples/examples/quantized/main.rs b/candle-examples/examples/quantized/main.rs
index 209a0f55..661cce5a 100644
--- a/candle-examples/examples/quantized/main.rs
+++ b/candle-examples/examples/quantized/main.rs
@@ -533,22 +533,6 @@ fn print_token(next_token: u32, tokenizer: &Tokenizer) {
}
}
-fn apply_repeat_penalty(logits: &Tensor, penalty: f32, context: &[u32]) -> Result<Tensor> {
- let mut logits = logits.to_vec1::<f32>()?;
- let context: std::collections::HashSet<_> = context.iter().collect();
- for (token_id, logit) in logits.iter_mut().enumerate() {
- if context.contains(&(token_id as u32)) {
- if *logit >= 0. {
- *logit /= penalty
- } else {
- *logit *= penalty
- }
- }
- }
- let logits_len = logits.len();
- Tensor::from_vec(logits, logits_len, &Device::Cpu)
-}
-
fn format_size(size_in_bytes: usize) -> String {
if size_in_bytes < 1_000 {
format!("{}B", size_in_bytes)
@@ -670,7 +654,11 @@ fn main() -> anyhow::Result<()> {
logits
} else {
let start_at = all_tokens.len().saturating_sub(args.repeat_last_n);
- apply_repeat_penalty(&logits, args.repeat_penalty, &all_tokens[start_at..])?
+ candle_transformers::utils::apply_repeat_penalty(
+ &logits,
+ args.repeat_penalty,
+ &all_tokens[start_at..],
+ )?
};
next_token = logits_processor.sample(&logits)?;
all_tokens.push(next_token);