summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author    Laurent Mazare <laurent.mazare@gmail.com>  2024-04-04 09:27:54 +0200
committer GitHub <noreply@github.com>                2024-04-04 09:27:54 +0200
commit    f48c07e2428a6d777ffdea57a2d1ac6a7d58a8ee (patch)
tree      e371fa03e3a8a16ddbbab7563547cec242613d46
parent    8967c46563221c01db4fc6a920231a9ef0d6f7bc (diff)
download  candle-f48c07e2428a6d777ffdea57a2d1ac6a7d58a8ee.tar.gz
          candle-f48c07e2428a6d777ffdea57a2d1ac6a7d58a8ee.tar.bz2
          candle-f48c07e2428a6d777ffdea57a2d1ac6a7d58a8ee.zip
Include topk sampling in the quantized example. (#2005)
* Include topk sampling in the quantized example. * Also sample with top-k on the mistral side.
-rw-r--r--  candle-examples/examples/mistral/main.rs   | 24
-rw-r--r--  candle-examples/examples/quantized/main.rs | 26
-rw-r--r--  candle-transformers/src/generation/mod.rs  | 26
3 files changed, 66 insertions(+), 10 deletions(-)
diff --git a/candle-examples/examples/mistral/main.rs b/candle-examples/examples/mistral/main.rs
index c00af3fe..6aa3f51e 100644
--- a/candle-examples/examples/mistral/main.rs
+++ b/candle-examples/examples/mistral/main.rs
@@ -13,7 +13,7 @@ use candle_transformers::models::quantized_mistral::Model as QMistral;
use candle::{DType, Device, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
-use candle_transformers::generation::LogitsProcessor;
+use candle_transformers::generation::{LogitsProcessor, Sampling};
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;
@@ -39,11 +39,26 @@ impl TextGeneration {
seed: u64,
temp: Option<f64>,
top_p: Option<f64>,
+ top_k: Option<usize>,
repeat_penalty: f32,
repeat_last_n: usize,
device: &Device,
) -> Self {
- let logits_processor = LogitsProcessor::new(seed, temp, top_p);
+ let logits_processor = {
+ let temperature = temp.unwrap_or(0.);
+ let sampling = if temperature <= 0. {
+ Sampling::ArgMax
+ } else {
+ match (top_k, top_p) {
+ (None, None) => Sampling::All { temperature },
+ (Some(k), None) => Sampling::TopK { k, temperature },
+ (None, Some(p)) => Sampling::TopP { p, temperature },
+ (Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
+ }
+ };
+ LogitsProcessor::from_sampling(seed, sampling)
+ };
+
Self {
model,
tokenizer: TokenOutputStream::new(tokenizer),
@@ -159,6 +174,10 @@ struct Args {
#[arg(long)]
top_p: Option<f64>,
+ /// Only sample among the top K samples.
+ #[arg(long)]
+ top_k: Option<usize>,
+
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
@@ -314,6 +333,7 @@ fn main() -> Result<()> {
args.seed,
args.temperature,
args.top_p,
+ args.top_k,
args.repeat_penalty,
args.repeat_last_n,
&device,
diff --git a/candle-examples/examples/quantized/main.rs b/candle-examples/examples/quantized/main.rs
index b03768ed..ea7f70eb 100644
--- a/candle-examples/examples/quantized/main.rs
+++ b/candle-examples/examples/quantized/main.rs
@@ -10,7 +10,7 @@ use tokenizers::Tokenizer;
use candle::quantized::{ggml_file, gguf_file};
use candle::Tensor;
-use candle_transformers::generation::LogitsProcessor;
+use candle_transformers::generation::{LogitsProcessor, Sampling};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_transformers::models::quantized_llama as model;
@@ -200,6 +200,10 @@ struct Args {
#[arg(long)]
top_p: Option<f64>,
+ /// Only sample among the top K samples.
+ #[arg(long)]
+ top_k: Option<usize>,
+
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
@@ -349,11 +353,6 @@ fn main() -> anyhow::Result<()> {
#[cfg(feature = "cuda")]
candle::quantized::cuda::set_force_dmmv(args.force_dmmv);
- let temperature = if args.temperature == 0. {
- None
- } else {
- Some(args.temperature)
- };
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
@@ -500,7 +499,20 @@ fn main() -> anyhow::Result<()> {
prompt_tokens
};
let mut all_tokens = vec![];
- let mut logits_processor = LogitsProcessor::new(args.seed, temperature, args.top_p);
+ let mut logits_processor = {
+ let temperature = args.temperature;
+ let sampling = if temperature <= 0. {
+ Sampling::ArgMax
+ } else {
+ match (args.top_k, args.top_p) {
+ (None, None) => Sampling::All { temperature },
+ (Some(k), None) => Sampling::TopK { k, temperature },
+ (None, Some(p)) => Sampling::TopP { p, temperature },
+ (Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
+ }
+ };
+ LogitsProcessor::from_sampling(args.seed, sampling)
+ };
let start_prompt_processing = std::time::Instant::now();
let mut next_token = if !args.split_prompt {
diff --git a/candle-transformers/src/generation/mod.rs b/candle-transformers/src/generation/mod.rs
index 257d9171..c250a186 100644
--- a/candle-transformers/src/generation/mod.rs
+++ b/candle-transformers/src/generation/mod.rs
@@ -7,6 +7,7 @@ pub enum Sampling {
All { temperature: f64 },
TopK { k: usize, temperature: f64 },
TopP { p: f64, temperature: f64 },
+ TopKThenTopP { k: usize, p: f64, temperature: f64 },
}
pub struct LogitsProcessor {
@@ -77,7 +78,6 @@ impl LogitsProcessor {
self.sample_multinomial(prs)
} else {
let mut argsort_indices = (0..prs.len()).collect::<Vec<_>>();
- // Sort by descending probability.
let (indices, _, _) =
argsort_indices.select_nth_unstable_by(top_k, |&i, &j| prs[j].total_cmp(&prs[i]));
let prs = indices.iter().map(|&i| prs[i]).collect::<Vec<_>>();
@@ -86,6 +86,26 @@ impl LogitsProcessor {
}
}
+ // top-k sampling samples from the k tokens with the largest probabilities.
+ // then top-p sampling.
+ fn sample_topk_topp(&mut self, prs: &mut Vec<f32>, top_k: usize, top_p: f32) -> Result<u32> {
+ if top_k >= prs.len() {
+ self.sample_topp(prs, top_p)
+ } else {
+ let mut argsort_indices = (0..prs.len()).collect::<Vec<_>>();
+ let (indices, _, _) =
+ argsort_indices.select_nth_unstable_by(top_k, |&i, &j| prs[j].total_cmp(&prs[i]));
+ let mut prs = indices.iter().map(|&i| prs[i]).collect::<Vec<_>>();
+ let sum_p = prs.iter().sum::<f32>();
+ let index = if top_p <= 0.0 || top_p >= sum_p {
+ self.sample_multinomial(&prs)?
+ } else {
+ self.sample_topp(&mut prs, top_p)?
+ };
+ Ok(indices[index as usize] as u32)
+ }
+ }
+
pub fn sample(&mut self, logits: &Tensor) -> Result<u32> {
self.sample_f(logits, |_| {})
}
@@ -120,6 +140,10 @@ impl LogitsProcessor {
let mut prs = prs(*temperature)?;
self.sample_topk(&mut prs, *k)?
}
+ Sampling::TopKThenTopP { k, p, temperature } => {
+ let mut prs = prs(*temperature)?;
+ self.sample_topk_topp(&mut prs, *k, *p as f32)?
+ }
};
Ok(next_token)
}