summary refs log tree commit diff
path: root/candle-transformers
diff options
context:
space:
mode:
author Laurent Mazare <laurent.mazare@gmail.com> 2024-03-23 15:26:09 +0100
committer GitHub <noreply@github.com> 2024-03-23 15:26:09 +0100
commit a62a97340c3f11fc7d804d8c6138e3da7e9d7648 (patch)
tree 5abec3335d016a5019991238df4f958534ce3318 /candle-transformers
parent fdfe8fd129a0f755f380d4a38f11207c28fc8ee4 (diff)
download candle-a62a97340c3f11fc7d804d8c6138e3da7e9d7648.tar.gz
candle-a62a97340c3f11fc7d804d8c6138e3da7e9d7648.tar.bz2
candle-a62a97340c3f11fc7d804d8c6138e3da7e9d7648.zip
Add topk sampling. (#1923)
Diffstat (limited to 'candle-transformers')
-rw-r--r-- candle-transformers/src/generation/mod.rs 85
-rw-r--r-- candle-transformers/tests/generation_tests.rs 27
2 files changed, 88 insertions, 24 deletions
diff --git a/candle-transformers/src/generation/mod.rs b/candle-transformers/src/generation/mod.rs
index b1a567c3..530a6b48 100644
--- a/candle-transformers/src/generation/mod.rs
+++ b/candle-transformers/src/generation/mod.rs
@@ -1,24 +1,35 @@
use candle::{DType, Error, Result, Tensor};
use rand::{distributions::Distribution, SeedableRng};
+#[derive(Clone, PartialEq, Debug)]
+pub enum Sampling {
+ ArgMax,
+ All { temperature: f64 },
+ TopK { k: usize, temperature: f64 },
+ TopP { p: f64, temperature: f64 },
+}
+
pub struct LogitsProcessor {
rng: rand::rngs::StdRng,
- temperature: Option<f64>,
- top_p: Option<f64>,
+ sampling: Sampling,
}
impl LogitsProcessor {
+ pub fn from_sampling(seed: u64, sampling: Sampling) -> Self {
+ let rng = rand::rngs::StdRng::seed_from_u64(seed);
+ Self { rng, sampling }
+ }
+
pub fn new(seed: u64, temperature: Option<f64>, top_p: Option<f64>) -> Self {
- let temperature = if temperature.map_or(true, |v| v < 1e-7) {
- None
- } else {
- temperature
+ let temperature = temperature.and_then(|v| if v < 1e-7 { None } else { Some(v) });
+ let sampling = match temperature {
+ None => Sampling::ArgMax,
+ Some(temperature) => match top_p {
+ None => Sampling::All { temperature },
+ Some(p) => Sampling::TopP { p, temperature },
+ },
};
- Self {
- rng: rand::rngs::StdRng::seed_from_u64(seed),
- temperature,
- top_p,
- }
+ Self::from_sampling(seed, sampling)
}
fn sample_argmax(&mut self, logits: Tensor) -> Result<u32> {
@@ -38,14 +49,14 @@ impl LogitsProcessor {
Ok(next_token)
}
+ /// top-p sampling (or "nucleus sampling") samples from the smallest set of most likely tokens
+ /// whose cumulative probability exceeds top_p. This way we never sample tokens that have very
+ /// low probabilities and are less likely to go "off the rails".
fn sample_topp(&mut self, prs: &mut Vec<f32>, top_p: f32) -> Result<u32> {
- // top-p sampling (or "nucleus sampling") samples from the smallest set of
- // tokens that exceed probability top_p. This way we never sample tokens that
- // have very low probabilities and are less likely to go "off the rails".
let mut argsort_indices = (0..prs.len()).collect::<Vec<_>>();
// Sort by descending probability.
- argsort_indices.sort_by(|&i, &j| prs[j].partial_cmp(&prs[i]).unwrap());
+ argsort_indices.sort_by(|&i, &j| prs[j].total_cmp(&prs[i]));
// Clamp smaller probabilities to zero.
let mut cumsum = 0.;
@@ -60,23 +71,49 @@ impl LogitsProcessor {
self.sample_multinomial(prs)
}
+ // top-k sampling samples from the k tokens with the largest probabilities.
+ fn sample_topk(&mut self, prs: &mut Vec<f32>, top_k: usize) -> Result<u32> {
+ if top_k >= prs.len() {
+ self.sample_multinomial(prs)
+ } else {
+ let mut argsort_indices = (0..prs.len()).collect::<Vec<_>>();
+ // Partially order the indices so that the top_k most probable ones come first (unsorted).
+ let (indices, _, _) =
+ argsort_indices.select_nth_unstable_by(top_k, |&i, &j| prs[j].total_cmp(&prs[i]));
+ let prs = indices.iter().map(|&i| prs[i]).collect::<Vec<_>>();
+ let index = self.sample_multinomial(&prs)?;
+ Ok(indices[index as usize] as u32)
+ }
+ }
+
pub fn sample(&mut self, logits: &Tensor) -> Result<u32> {
let logits = logits.to_dtype(DType::F32)?;
- let next_token = match self.temperature {
- None => self.sample_argmax(logits)?,
- Some(temperature) => {
- let logits = &(&logits / temperature)?;
- let prs = candle_nn::ops::softmax_last_dim(logits)?;
- let mut prs: Vec<f32> = prs.to_vec1()?;
- let top_p = self.top_p.unwrap_or(1.);
- if top_p <= 0.0 || top_p >= 1.0 {
+ let prs = |temperature: f64| -> Result<Vec<f32>> {
+ let logits = (&logits / temperature)?;
+ let prs = candle_nn::ops::softmax_last_dim(&logits)?;
+ prs.to_vec1()
+ };
+
+ let next_token = match &self.sampling {
+ Sampling::ArgMax => self.sample_argmax(logits)?,
+ Sampling::All { temperature } => {
+ let prs = prs(*temperature)?;
+ self.sample_multinomial(&prs)?
+ }
+ Sampling::TopP { p, temperature } => {
+ let mut prs = prs(*temperature)?;
+ if *p <= 0.0 || *p >= 1.0 {
// simply sample from the predicted probability distribution
self.sample_multinomial(&prs)?
} else {
// top-p (nucleus) sampling, clamping the least likely tokens to zero
- self.sample_topp(&mut prs, top_p as f32)?
+ self.sample_topp(&mut prs, *p as f32)?
}
}
+ Sampling::TopK { k, temperature } => {
+ let mut prs = prs(*temperature)?;
+ self.sample_topk(&mut prs, *k)?
+ }
};
Ok(next_token)
}
diff --git a/candle-transformers/tests/generation_tests.rs b/candle-transformers/tests/generation_tests.rs
index 76f994d0..cc499a44 100644
--- a/candle-transformers/tests/generation_tests.rs
+++ b/candle-transformers/tests/generation_tests.rs
@@ -27,3 +27,30 @@ fn sample_with_top_p() -> Result<()> {
assert_eq!(token, 2);
Ok(())
}
+
+#[test]
+fn sample_with_top_k() -> Result<()> {
+ let mut logits_process = LogitsProcessor::from_sampling(
+ 42,
+ candle_transformers::generation::Sampling::TopK {
+ k: 1,
+ temperature: 1.0,
+ },
+ );
+ let logits = Tensor::new(&[0.1, 0.2, 0.3, 0.4], &Device::Cpu)?;
+ let token = logits_process.sample(&logits)?;
+ assert_eq!(token, 3);
+ let mut logits_process = LogitsProcessor::from_sampling(
+ 42,
+ candle_transformers::generation::Sampling::TopK {
+ k: 2,
+ temperature: 1.0,
+ },
+ );
+ let logits = Tensor::new(&[0.1, 0.2, 0.3, 0.4], &Device::Cpu)?;
+ let token = logits_process.sample(&logits)?;
+ assert_eq!(token, 3);
+ let token = logits_process.sample(&logits)?;
+ assert_eq!(token, 2);
+ Ok(())
+}