summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-transformers/src/generation/mod.rs3
1 files changed, 2 insertions, 1 deletions
diff --git a/candle-transformers/src/generation/mod.rs b/candle-transformers/src/generation/mod.rs
index d2ac33e9..b1d20168 100644
--- a/candle-transformers/src/generation/mod.rs
+++ b/candle-transformers/src/generation/mod.rs
@@ -16,7 +16,8 @@ impl LogitsProcessor {
pub fn sample(&mut self, logits: &Tensor) -> Result<u32> {
let logits = logits.to_dtype(DType::F32)?;
- let next_token = if let Some(temperature) = self.temperature {
+ let temperature = self.temperature.unwrap_or(0.);
+ let next_token = if temperature > 0. {
let prs = candle_nn::ops::softmax(&(&logits / temperature)?, D::Minus1)?;
let prs: Vec<f32> = prs.to_vec1()?;
let distr = rand::distributions::WeightedIndex::new(prs).map_err(Error::wrap)?;