Diffstat (limited to 'candle-examples/examples/mistral')
-rw-r--r--  candle-examples/examples/mistral/main.rs | 24
1 file changed, 22 insertions(+), 2 deletions(-)
diff --git a/candle-examples/examples/mistral/main.rs b/candle-examples/examples/mistral/main.rs
index c00af3fe..6aa3f51e 100644
--- a/candle-examples/examples/mistral/main.rs
+++ b/candle-examples/examples/mistral/main.rs
@@ -13,7 +13,7 @@ use candle_transformers::models::quantized_mistral::Model as QMistral;
 use candle::{DType, Device, Tensor};
 use candle_examples::token_output_stream::TokenOutputStream;
 use candle_nn::VarBuilder;
-use candle_transformers::generation::LogitsProcessor;
+use candle_transformers::generation::{LogitsProcessor, Sampling};
 use hf_hub::{api::sync::Api, Repo, RepoType};
 use tokenizers::Tokenizer;
 
@@ -39,11 +39,26 @@ impl TextGeneration {
         seed: u64,
         temp: Option<f64>,
         top_p: Option<f64>,
+        top_k: Option<usize>,
         repeat_penalty: f32,
         repeat_last_n: usize,
         device: &Device,
     ) -> Self {
-        let logits_processor = LogitsProcessor::new(seed, temp, top_p);
+        let logits_processor = {
+            let temperature = temp.unwrap_or(0.);
+            let sampling = if temperature <= 0. {
+                Sampling::ArgMax
+            } else {
+                match (top_k, top_p) {
+                    (None, None) => Sampling::All { temperature },
+                    (Some(k), None) => Sampling::TopK { k, temperature },
+                    (None, Some(p)) => Sampling::TopP { p, temperature },
+                    (Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
+                }
+            };
+            LogitsProcessor::from_sampling(seed, sampling)
+        };
+
         Self {
             model,
             tokenizer: TokenOutputStream::new(tokenizer),
@@ -159,6 +174,10 @@ struct Args {
     #[arg(long)]
     top_p: Option<f64>,
 
+    /// Only sample among the top K samples.
+    #[arg(long)]
+    top_k: Option<usize>,
+
     /// The seed to use when generating random samples.
     #[arg(long, default_value_t = 299792458)]
     seed: u64,
@@ -314,6 +333,7 @@ fn main() -> Result<()> {
         args.seed,
         args.temperature,
         args.top_p,
+        args.top_k,
         args.repeat_penalty,
         args.repeat_last_n,
         &device,
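The substance of the change is how the `LogitsProcessor` is constructed: instead of `LogitsProcessor::new(seed, temp, top_p)`, the example now maps the combination of temperature, top-k, and top-p onto a `Sampling` strategy and hands it to `LogitsProcessor::from_sampling`. Below is a minimal sketch of that selection logic pulled out into a free function so it can be read and tested on its own; the helper name `build_logits_processor` is hypothetical, while `Sampling` and `from_sampling` are the items the diff imports from `candle_transformers::generation`.

use candle_transformers::generation::{LogitsProcessor, Sampling};

// Hypothetical helper mirroring the selection logic added in the diff.
fn build_logits_processor(
    seed: u64,
    temp: Option<f64>,
    top_k: Option<usize>,
    top_p: Option<f64>,
) -> LogitsProcessor {
    // No temperature (or a non-positive one) means greedy decoding: always take the argmax token.
    let temperature = temp.unwrap_or(0.);
    let sampling = if temperature <= 0. {
        Sampling::ArgMax
    } else {
        // Otherwise the presence of top-k and/or top-p selects the sampling strategy.
        match (top_k, top_p) {
            (None, None) => Sampling::All { temperature },
            (Some(k), None) => Sampling::TopK { k, temperature },
            (None, Some(p)) => Sampling::TopP { p, temperature },
            (Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
        }
    };
    LogitsProcessor::from_sampling(seed, sampling)
}

Since the new field is declared as `#[arg(long)] top_k: Option<usize>`, clap should expose it as a `--top-k` flag. Assuming the example's existing `--prompt` and `--temperature` flags, an invocation exercising the new code path might look like:

cargo run --example mistral --release -- --prompt "Hello, my name is" --temperature 0.8 --top-k 40 --top-p 0.9

Passing both cutoffs selects `Sampling::TopKThenTopP`, which, as the variant name suggests, applies the top-k filter first and then nucleus (top-p) sampling over the tokens that remain.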