summary | refs | log | tree | commit | diff
path: root/candle-transformers/src/generation/mod.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-transformers/src/generation/mod.rs')
-rw-r--r-- candle-transformers/src/generation/mod.rs 26
1 files changed, 25 insertions, 1 deletions
diff --git a/candle-transformers/src/generation/mod.rs b/candle-transformers/src/generation/mod.rs
index 257d9171..c250a186 100644
--- a/candle-transformers/src/generation/mod.rs
+++ b/candle-transformers/src/generation/mod.rs
@@ -7,6 +7,7 @@ pub enum Sampling {
All { temperature: f64 },
TopK { k: usize, temperature: f64 },
TopP { p: f64, temperature: f64 },
+ TopKThenTopP { k: usize, p: f64, temperature: f64 },
}
pub struct LogitsProcessor {
@@ -77,7 +78,6 @@ impl LogitsProcessor {
self.sample_multinomial(prs)
} else {
let mut argsort_indices = (0..prs.len()).collect::<Vec<_>>();
- // Sort by descending probability.
let (indices, _, _) =
argsort_indices.select_nth_unstable_by(top_k, |&i, &j| prs[j].total_cmp(&prs[i]));
let prs = indices.iter().map(|&i| prs[i]).collect::<Vec<_>>();
@@ -86,6 +86,26 @@ impl LogitsProcessor {
}
}
+ // Top-k then top-p sampling: keep the k tokens with the largest probabilities, then sample among
+ // them with top-p. NOTE(review): the top_k >= prs.len() branch skips the top_p <= 0.0 guard used below - confirm.
+ fn sample_topk_topp(&mut self, prs: &mut Vec<f32>, top_k: usize, top_p: f32) -> Result<u32> {
+ if top_k >= prs.len() {
+ self.sample_topp(prs, top_p)
+ } else {
+ let mut argsort_indices = (0..prs.len()).collect::<Vec<_>>();
+ let (indices, _, _) =
+ argsort_indices.select_nth_unstable_by(top_k, |&i, &j| prs[j].total_cmp(&prs[i]));
+ let mut prs = indices.iter().map(|&i| prs[i]).collect::<Vec<_>>();
+ let sum_p = prs.iter().sum::<f32>();
+ let index = if top_p <= 0.0 || top_p >= sum_p {
+ self.sample_multinomial(&prs)?
+ } else {
+ self.sample_topp(&mut prs, top_p)?
+ };
+ Ok(indices[index as usize] as u32)
+ }
+ }
+
pub fn sample(&mut self, logits: &Tensor) -> Result<u32> {
self.sample_f(logits, |_| {})
}
@@ -120,6 +140,10 @@ impl LogitsProcessor {
let mut prs = prs(*temperature)?;
self.sample_topk(&mut prs, *k)?
}
+ Sampling::TopKThenTopP { k, p, temperature } => {
+ let mut prs = prs(*temperature)?;
+ self.sample_topk_topp(&mut prs, *k, *p as f32)?
+ }
};
Ok(next_token)
}