diff options
Diffstat (limited to 'candle-transformers/src/generation/mod.rs')
-rw-r--r-- | candle-transformers/src/generation/mod.rs | 26 |
1 files changed, 25 insertions, 1 deletions
diff --git a/candle-transformers/src/generation/mod.rs b/candle-transformers/src/generation/mod.rs index 257d9171..c250a186 100644 --- a/candle-transformers/src/generation/mod.rs +++ b/candle-transformers/src/generation/mod.rs @@ -7,6 +7,7 @@ pub enum Sampling { All { temperature: f64 }, TopK { k: usize, temperature: f64 }, TopP { p: f64, temperature: f64 }, + TopKThenTopP { k: usize, p: f64, temperature: f64 }, } pub struct LogitsProcessor { @@ -77,7 +78,6 @@ impl LogitsProcessor { self.sample_multinomial(prs) } else { let mut argsort_indices = (0..prs.len()).collect::<Vec<_>>(); - // Sort by descending probability. let (indices, _, _) = argsort_indices.select_nth_unstable_by(top_k, |&i, &j| prs[j].total_cmp(&prs[i])); let prs = indices.iter().map(|&i| prs[i]).collect::<Vec<_>>(); @@ -86,6 +86,26 @@ impl LogitsProcessor { } } + // top-k sampling samples from the k tokens with the largest probabilities. + // then top-p sampling. + fn sample_topk_topp(&mut self, prs: &mut Vec<f32>, top_k: usize, top_p: f32) -> Result<u32> { + if top_k >= prs.len() { + self.sample_topp(prs, top_p) + } else { + let mut argsort_indices = (0..prs.len()).collect::<Vec<_>>(); + let (indices, _, _) = + argsort_indices.select_nth_unstable_by(top_k, |&i, &j| prs[j].total_cmp(&prs[i])); + let mut prs = indices.iter().map(|&i| prs[i]).collect::<Vec<_>>(); + let sum_p = prs.iter().sum::<f32>(); + let index = if top_p <= 0.0 || top_p >= sum_p { + self.sample_multinomial(&prs)? + } else { + self.sample_topp(&mut prs, top_p)? + }; + Ok(indices[index as usize] as u32) + } + } + pub fn sample(&mut self, logits: &Tensor) -> Result<u32> { self.sample_f(logits, |_| {}) } @@ -120,6 +140,10 @@ impl LogitsProcessor { let mut prs = prs(*temperature)?; self.sample_topk(&mut prs, *k)? } + Sampling::TopKThenTopP { k, p, temperature } => { + let mut prs = prs(*temperature)?; + self.sample_topk_topp(&mut prs, *k, *p as f32)? + } }; Ok(next_token) } |