diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-09-25 20:53:30 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-09-25 20:53:30 +0100 |
commit | c78a29432329e2947e0c417d8dc1e5ed6aa25bad (patch) | |
tree | 34b7fb2a28a096ce6323a81a99d704a61dd42ca5 /candle-examples/examples/phi/main.rs | |
parent | a36d88325492c5e721b2045b2988cc11ef44fc0e (diff) | |
download | candle-c78a29432329e2947e0c417d8dc1e5ed6aa25bad.tar.gz candle-c78a29432329e2947e0c417d8dc1e5ed6aa25bad.tar.bz2 candle-c78a29432329e2947e0c417d8dc1e5ed6aa25bad.zip |
Add some repeat penalty to the phi example. (#961)
Diffstat (limited to 'candle-examples/examples/phi/main.rs')
-rw-r--r-- | candle-examples/examples/phi/main.rs | 41 |
1 files changed, 40 insertions, 1 deletions
diff --git a/candle-examples/examples/phi/main.rs b/candle-examples/examples/phi/main.rs index 5cdbb4b9..fe365e18 100644 --- a/candle-examples/examples/phi/main.rs +++ b/candle-examples/examples/phi/main.rs @@ -26,15 +26,20 @@ struct TextGeneration { device: Device, tokenizer: Tokenizer, logits_processor: LogitsProcessor, + repeat_penalty: f32, + repeat_last_n: usize, } impl TextGeneration { + #[allow(clippy::too_many_arguments)] fn new( model: Model, tokenizer: Tokenizer, seed: u64, temp: Option<f64>, top_p: Option<f64>, + repeat_penalty: f32, + repeat_last_n: usize, device: &Device, ) -> Self { let logits_processor = LogitsProcessor::new(seed, temp, top_p); @@ -42,6 +47,8 @@ impl TextGeneration { model, tokenizer, logits_processor, + repeat_penalty, + repeat_last_n, device: device.clone(), } } @@ -69,6 +76,16 @@ impl TextGeneration { Model::Quantized(m) => m.forward(&input)?, }; let logits = logits.squeeze(0)?.to_dtype(DType::F32)?; + let logits = if self.repeat_penalty == 1. { + logits + } else { + let start_at = tokens.len().saturating_sub(self.repeat_last_n); + candle_transformers::utils::apply_repeat_penalty( + &logits, + self.repeat_penalty, + &tokens[start_at..], + )? + }; let next_token = self.logits_processor.sample(&logits)?; tokens.push(next_token); @@ -127,6 +144,14 @@ struct Args { #[arg(long)] quantized: bool, + + /// Penalty to be applied for repeating tokens, 1. means no penalty. + #[arg(long, default_value_t = 1.1)] + repeat_penalty: f32, + + /// The context size to consider for the repeat penalty. + #[arg(long, default_value_t = 64)] + repeat_last_n: usize, } fn main() -> Result<()> { @@ -134,7 +159,6 @@ fn main() -> Result<()> { use tracing_subscriber::prelude::*; let args = Args::parse(); - let _guard = if args.tracing { let (chrome_layer, guard) = ChromeLayerBuilder::new().build(); tracing_subscriber::registry().with(chrome_layer).init(); @@ -142,6 +166,19 @@ fn main() -> Result<()> { } else { None }; + println!( + "avx: {}, neon: {}, simd128: {}, f16c: {}", + candle::utils::with_avx(), + candle::utils::with_neon(), + candle::utils::with_simd128(), + candle::utils::with_f16c() + ); + println!( + "temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}", + args.temperature.unwrap_or(0.), + args.repeat_penalty, + args.repeat_last_n + ); let start = std::time::Instant::now(); let api = Api::new()?; @@ -185,6 +222,8 @@ fn main() -> Result<()> { args.seed, args.temperature, args.top_p, + args.repeat_penalty, + args.repeat_last_n, &device, ); pipeline.run(&args.prompt, args.sample_len)?; |