diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-03-23 15:26:09 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-03-23 15:26:09 +0100 |
commit | a62a97340c3f11fc7d804d8c6138e3da7e9d7648 (patch) | |
tree | 5abec3335d016a5019991238df4f958534ce3318 /candle-transformers | |
parent | fdfe8fd129a0f755f380d4a38f11207c28fc8ee4 (diff) | |
download | candle-a62a97340c3f11fc7d804d8c6138e3da7e9d7648.tar.gz candle-a62a97340c3f11fc7d804d8c6138e3da7e9d7648.tar.bz2 candle-a62a97340c3f11fc7d804d8c6138e3da7e9d7648.zip |
Add topk sampling. (#1923)
Diffstat (limited to 'candle-transformers')
-rw-r--r-- | candle-transformers/src/generation/mod.rs | 85 | ||||
-rw-r--r-- | candle-transformers/tests/generation_tests.rs | 27 |
2 files changed, 88 insertions, 24 deletions
```diff
diff --git a/candle-transformers/src/generation/mod.rs b/candle-transformers/src/generation/mod.rs
index b1a567c3..530a6b48 100644
--- a/candle-transformers/src/generation/mod.rs
+++ b/candle-transformers/src/generation/mod.rs
@@ -1,24 +1,35 @@
 use candle::{DType, Error, Result, Tensor};
 use rand::{distributions::Distribution, SeedableRng};
 
+#[derive(Clone, PartialEq, Debug)]
+pub enum Sampling {
+    ArgMax,
+    All { temperature: f64 },
+    TopK { k: usize, temperature: f64 },
+    TopP { p: f64, temperature: f64 },
+}
+
 pub struct LogitsProcessor {
     rng: rand::rngs::StdRng,
-    temperature: Option<f64>,
-    top_p: Option<f64>,
+    sampling: Sampling,
 }
 
 impl LogitsProcessor {
+    pub fn from_sampling(seed: u64, sampling: Sampling) -> Self {
+        let rng = rand::rngs::StdRng::seed_from_u64(seed);
+        Self { rng, sampling }
+    }
+
     pub fn new(seed: u64, temperature: Option<f64>, top_p: Option<f64>) -> Self {
-        let temperature = if temperature.map_or(true, |v| v < 1e-7) {
-            None
-        } else {
-            temperature
+        let temperature = temperature.and_then(|v| if v < 1e-7 { None } else { Some(v) });
+        let sampling = match temperature {
+            None => Sampling::ArgMax,
+            Some(temperature) => match top_p {
+                None => Sampling::All { temperature },
+                Some(p) => Sampling::TopP { p, temperature },
+            },
         };
-        Self {
-            rng: rand::rngs::StdRng::seed_from_u64(seed),
-            temperature,
-            top_p,
-        }
+        Self::from_sampling(seed, sampling)
     }
 
     fn sample_argmax(&mut self, logits: Tensor) -> Result<u32> {
@@ -38,14 +49,14 @@ impl LogitsProcessor {
         Ok(next_token)
     }
 
+    /// top-p sampling (or "nucleus sampling") samples from the smallest set of tokens that exceed
+    /// probability top_p. This way we never sample tokens that have very low probabilities and are
+    /// less likely to go "off the rails".
     fn sample_topp(&mut self, prs: &mut Vec<f32>, top_p: f32) -> Result<u32> {
-        // top-p sampling (or "nucleus sampling") samples from the smallest set of
-        // tokens that exceed probability top_p. This way we never sample tokens that
-        // have very low probabilities and are less likely to go "off the rails".
         let mut argsort_indices = (0..prs.len()).collect::<Vec<_>>();
 
         // Sort by descending probability.
-        argsort_indices.sort_by(|&i, &j| prs[j].partial_cmp(&prs[i]).unwrap());
+        argsort_indices.sort_by(|&i, &j| prs[j].total_cmp(&prs[i]));
 
         // Clamp smaller probabilities to zero.
         let mut cumsum = 0.;
@@ -60,23 +71,49 @@ impl LogitsProcessor {
         self.sample_multinomial(prs)
     }
 
+    // top-k sampling samples from the k tokens with the largest probabilities.
+    fn sample_topk(&mut self, prs: &mut Vec<f32>, top_k: usize) -> Result<u32> {
+        if top_k >= prs.len() {
+            self.sample_multinomial(prs)
+        } else {
+            let mut argsort_indices = (0..prs.len()).collect::<Vec<_>>();
+            // Sort by descending probability.
+            let (indices, _, _) =
+                argsort_indices.select_nth_unstable_by(top_k, |&i, &j| prs[j].total_cmp(&prs[i]));
+            let prs = indices.iter().map(|&i| prs[i]).collect::<Vec<_>>();
+            let index = self.sample_multinomial(&prs)?;
+            Ok(indices[index as usize] as u32)
+        }
+    }
+
     pub fn sample(&mut self, logits: &Tensor) -> Result<u32> {
         let logits = logits.to_dtype(DType::F32)?;
-        let next_token = match self.temperature {
-            None => self.sample_argmax(logits)?,
-            Some(temperature) => {
-                let logits = &(&logits / temperature)?;
-                let prs = candle_nn::ops::softmax_last_dim(logits)?;
-                let mut prs: Vec<f32> = prs.to_vec1()?;
-                let top_p = self.top_p.unwrap_or(1.);
-                if top_p <= 0.0 || top_p >= 1.0 {
+        let prs = |temperature: f64| -> Result<Vec<f32>> {
+            let logits = (&logits / temperature)?;
+            let prs = candle_nn::ops::softmax_last_dim(&logits)?;
+            prs.to_vec1()
+        };
+
+        let next_token = match &self.sampling {
+            Sampling::ArgMax => self.sample_argmax(logits)?,
+            Sampling::All { temperature } => {
+                let prs = prs(*temperature)?;
+                self.sample_multinomial(&prs)?
+            }
+            Sampling::TopP { p, temperature } => {
+                let mut prs = prs(*temperature)?;
+                if *p <= 0.0 || *p >= 1.0 {
                     // simply sample from the predicted probability distribution
                     self.sample_multinomial(&prs)?
                 } else {
                     // top-p (nucleus) sampling, clamping the least likely tokens to zero
-                    self.sample_topp(&mut prs, top_p as f32)?
+                    self.sample_topp(&mut prs, *p as f32)?
                 }
             }
+            Sampling::TopK { k, temperature } => {
+                let mut prs = prs(*temperature)?;
+                self.sample_topk(&mut prs, *k)?
+            }
         };
         Ok(next_token)
     }
diff --git a/candle-transformers/tests/generation_tests.rs b/candle-transformers/tests/generation_tests.rs
index 76f994d0..cc499a44 100644
--- a/candle-transformers/tests/generation_tests.rs
+++ b/candle-transformers/tests/generation_tests.rs
@@ -27,3 +27,30 @@ fn sample_with_top_p() -> Result<()> {
     assert_eq!(token, 2);
     Ok(())
 }
+
+#[test]
+fn sample_with_top_k() -> Result<()> {
+    let mut logits_process = LogitsProcessor::from_sampling(
+        42,
+        candle_transformers::generation::Sampling::TopK {
+            k: 1,
+            temperature: 1.0,
+        },
+    );
+    let logits = Tensor::new(&[0.1, 0.2, 0.3, 0.4], &Device::Cpu)?;
+    let token = logits_process.sample(&logits)?;
+    assert_eq!(token, 3);
+    let mut logits_process = LogitsProcessor::from_sampling(
+        42,
+        candle_transformers::generation::Sampling::TopK {
+            k: 2,
+            temperature: 1.0,
+        },
+    );
+    let logits = Tensor::new(&[0.1, 0.2, 0.3, 0.4], &Device::Cpu)?;
+    let token = logits_process.sample(&logits)?;
+    assert_eq!(token, 3);
+    let token = logits_process.sample(&logits)?;
+    assert_eq!(token, 2);
+    Ok(())
+}
```