diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-04-04 09:27:54 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-04-04 09:27:54 +0200 |
commit | f48c07e2428a6d777ffdea57a2d1ac6a7d58a8ee (patch) | |
tree | e371fa03e3a8a16ddbbab7563547cec242613d46 /candle-examples/examples/quantized | |
parent | 8967c46563221c01db4fc6a920231a9ef0d6f7bc (diff) | |
download | candle-f48c07e2428a6d777ffdea57a2d1ac6a7d58a8ee.tar.gz candle-f48c07e2428a6d777ffdea57a2d1ac6a7d58a8ee.tar.bz2 candle-f48c07e2428a6d777ffdea57a2d1ac6a7d58a8ee.zip |
Include topk sampling in the quantized example. (#2005)
* Include topk sampling in the quantized example.
* Also sample with top-k on the mistral side.
Diffstat (limited to 'candle-examples/examples/quantized')
-rw-r--r-- | candle-examples/examples/quantized/main.rs | 26 |
1 file changed, 19 insertions(+), 7 deletions(-)
diff --git a/candle-examples/examples/quantized/main.rs b/candle-examples/examples/quantized/main.rs
index b03768ed..ea7f70eb 100644
--- a/candle-examples/examples/quantized/main.rs
+++ b/candle-examples/examples/quantized/main.rs
@@ -10,7 +10,7 @@ use tokenizers::Tokenizer;
 use candle::quantized::{ggml_file, gguf_file};
 use candle::Tensor;
-use candle_transformers::generation::LogitsProcessor;
+use candle_transformers::generation::{LogitsProcessor, Sampling};

 use candle_examples::token_output_stream::TokenOutputStream;
 use candle_transformers::models::quantized_llama as model;
@@ -200,6 +200,10 @@ struct Args {
     #[arg(long)]
     top_p: Option<f64>,

+    /// Only sample among the top K samples.
+    #[arg(long)]
+    top_k: Option<usize>,
+
     /// The seed to use when generating random samples.
     #[arg(long, default_value_t = 299792458)]
     seed: u64,
@@ -349,11 +353,6 @@ fn main() -> anyhow::Result<()> {
     #[cfg(feature = "cuda")]
     candle::quantized::cuda::set_force_dmmv(args.force_dmmv);

-    let temperature = if args.temperature == 0. {
-        None
-    } else {
-        Some(args.temperature)
-    };
     let _guard = if args.tracing {
         let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
         tracing_subscriber::registry().with(chrome_layer).init();
@@ -500,7 +499,20 @@ fn main() -> anyhow::Result<()> {
         prompt_tokens
     };
     let mut all_tokens = vec![];
-    let mut logits_processor = LogitsProcessor::new(args.seed, temperature, args.top_p);
+    let mut logits_processor = {
+        let temperature = args.temperature;
+        let sampling = if temperature <= 0. {
+            Sampling::ArgMax
+        } else {
+            match (args.top_k, args.top_p) {
+                (None, None) => Sampling::All { temperature },
+                (Some(k), None) => Sampling::TopK { k, temperature },
+                (None, Some(p)) => Sampling::TopP { p, temperature },
+                (Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
+            }
+        };
+        LogitsProcessor::from_sampling(args.seed, sampling)
+    };

     let start_prompt_processing = std::time::Instant::now();
     let mut next_token = if !args.split_prompt {