use candle::{DType, Error, Result, Tensor}; use rand::{distributions::Distribution, SeedableRng}; #[derive(Clone, PartialEq, Debug)] pub enum Sampling { ArgMax, All { temperature: f64 }, TopK { k: usize, temperature: f64 }, TopP { p: f64, temperature: f64 }, TopKThenTopP { k: usize, p: f64, temperature: f64 }, } pub struct LogitsProcessor { rng: rand::rngs::StdRng, sampling: Sampling, } impl LogitsProcessor { pub fn from_sampling(seed: u64, sampling: Sampling) -> Self { let rng = rand::rngs::StdRng::seed_from_u64(seed); Self { rng, sampling } } pub fn new(seed: u64, temperature: Option, top_p: Option) -> Self { let temperature = temperature.and_then(|v| if v < 1e-7 { None } else { Some(v) }); let sampling = match temperature { None => Sampling::ArgMax, Some(temperature) => match top_p { None => Sampling::All { temperature }, Some(p) => Sampling::TopP { p, temperature }, }, }; Self::from_sampling(seed, sampling) } fn sample_argmax(&mut self, logits: Tensor) -> Result { let logits_v: Vec = logits.to_vec1()?; let next_token = logits_v .iter() .enumerate() .max_by(|(_, u), (_, v)| u.total_cmp(v)) .map(|(i, _)| i as u32) .unwrap(); Ok(next_token) } fn sample_multinomial(&mut self, prs: &Vec) -> Result { let distr = rand::distributions::WeightedIndex::new(prs).map_err(Error::wrap)?; let next_token = distr.sample(&mut self.rng) as u32; Ok(next_token) } /// top-p sampling (or "nucleus sampling") samples from the smallest set of tokens that exceed /// probability top_p. This way we never sample tokens that have very low probabilities and are /// less likely to go "off the rails". fn sample_topp(&mut self, prs: &mut Vec, top_p: f32) -> Result { let mut argsort_indices = (0..prs.len()).collect::>(); // Sort by descending probability. argsort_indices.sort_by(|&i, &j| prs[j].total_cmp(&prs[i])); // Clamp smaller probabilities to zero. let mut cumsum = 0.; for index in &argsort_indices { if cumsum >= top_p { prs[*index] = 0.0; } else { cumsum += prs[*index]; } } // Sample with clamped probabilities. self.sample_multinomial(prs) } // top-k sampling samples from the k tokens with the largest probabilities. fn sample_topk(&mut self, prs: &mut Vec, top_k: usize) -> Result { if top_k >= prs.len() { self.sample_multinomial(prs) } else { let mut argsort_indices = (0..prs.len()).collect::>(); let (indices, _, _) = argsort_indices.select_nth_unstable_by(top_k, |&i, &j| prs[j].total_cmp(&prs[i])); let prs = indices.iter().map(|&i| prs[i]).collect::>(); let index = self.sample_multinomial(&prs)?; Ok(indices[index as usize] as u32) } } // top-k sampling samples from the k tokens with the largest probabilities. // then top-p sampling. fn sample_topk_topp(&mut self, prs: &mut Vec, top_k: usize, top_p: f32) -> Result { if top_k >= prs.len() { self.sample_topp(prs, top_p) } else { let mut argsort_indices = (0..prs.len()).collect::>(); let (indices, _, _) = argsort_indices.select_nth_unstable_by(top_k, |&i, &j| prs[j].total_cmp(&prs[i])); let mut prs = indices.iter().map(|&i| prs[i]).collect::>(); let sum_p = prs.iter().sum::(); let index = if top_p <= 0.0 || top_p >= sum_p { self.sample_multinomial(&prs)? } else { self.sample_topp(&mut prs, top_p)? }; Ok(indices[index as usize] as u32) } } pub fn sample(&mut self, logits: &Tensor) -> Result { self.sample_f(logits, |_| {}) } pub fn sample_f(&mut self, logits: &Tensor, f: impl FnOnce(&mut [f32])) -> Result { let logits = logits.to_dtype(DType::F32)?; let prs = |temperature: f64| -> Result> { let logits = (&logits / temperature)?; let prs = candle_nn::ops::softmax_last_dim(&logits)?; let mut prs = prs.to_vec1()?; f(&mut prs); Ok(prs) }; let next_token = match &self.sampling { Sampling::ArgMax => self.sample_argmax(logits)?, Sampling::All { temperature } => { let prs = prs(*temperature)?; self.sample_multinomial(&prs)? } Sampling::TopP { p, temperature } => { let mut prs = prs(*temperature)?; if *p <= 0.0 || *p >= 1.0 { // simply sample from the predicted probability distribution self.sample_multinomial(&prs)? } else { // top-p (nucleus) sampling, clamping the least likely tokens to zero self.sample_topp(&mut prs, *p as f32)? } } Sampling::TopK { k, temperature } => { let mut prs = prs(*temperature)?; self.sample_topk(&mut prs, *k)? } Sampling::TopKThenTopP { k, p, temperature } => { let mut prs = prs(*temperature)?; self.sample_topk_topp(&mut prs, *k, *p as f32)? } }; Ok(next_token) } }