summaryrefslogtreecommitdiff
path: root/candle-examples/examples/mistral
diff options
context:
space:
mode:
Diffstat (limited to 'candle-examples/examples/mistral')
-rw-r--r--candle-examples/examples/mistral/main.rs24
1 files changed, 22 insertions, 2 deletions
diff --git a/candle-examples/examples/mistral/main.rs b/candle-examples/examples/mistral/main.rs
index c00af3fe..6aa3f51e 100644
--- a/candle-examples/examples/mistral/main.rs
+++ b/candle-examples/examples/mistral/main.rs
@@ -13,7 +13,7 @@ use candle_transformers::models::quantized_mistral::Model as QMistral;
use candle::{DType, Device, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
-use candle_transformers::generation::LogitsProcessor;
+use candle_transformers::generation::{LogitsProcessor, Sampling};
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;
@@ -39,11 +39,26 @@ impl TextGeneration {
seed: u64,
temp: Option<f64>,
top_p: Option<f64>,
+ top_k: Option<usize>,
repeat_penalty: f32,
repeat_last_n: usize,
device: &Device,
) -> Self {
- let logits_processor = LogitsProcessor::new(seed, temp, top_p);
+ let logits_processor = {
+ let temperature = temp.unwrap_or(0.);
+ let sampling = if temperature <= 0. {
+ Sampling::ArgMax
+ } else {
+ match (top_k, top_p) {
+ (None, None) => Sampling::All { temperature },
+ (Some(k), None) => Sampling::TopK { k, temperature },
+ (None, Some(p)) => Sampling::TopP { p, temperature },
+ (Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
+ }
+ };
+ LogitsProcessor::from_sampling(seed, sampling)
+ };
+
Self {
model,
tokenizer: TokenOutputStream::new(tokenizer),
@@ -159,6 +174,10 @@ struct Args {
#[arg(long)]
top_p: Option<f64>,
+ /// Only sample among the top K samples.
+ #[arg(long)]
+ top_k: Option<usize>,
+
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
@@ -314,6 +333,7 @@ fn main() -> Result<()> {
args.seed,
args.temperature,
args.top_p,
+ args.top_k,
args.repeat_penalty,
args.repeat_last_n,
&device,