diff options
-rw-r--r-- | candle-transformers/src/generation/mod.rs | 3 |
1 files changed, 2 insertions, 1 deletions
diff --git a/candle-transformers/src/generation/mod.rs b/candle-transformers/src/generation/mod.rs index d2ac33e9..b1d20168 100644 --- a/candle-transformers/src/generation/mod.rs +++ b/candle-transformers/src/generation/mod.rs @@ -16,7 +16,8 @@ impl LogitsProcessor { pub fn sample(&mut self, logits: &Tensor) -> Result<u32> { let logits = logits.to_dtype(DType::F32)?; - let next_token = if let Some(temperature) = self.temperature { + let temperature = self.temperature.unwrap_or(0.); + let next_token = if temperature > 0. { let prs = candle_nn::ops::softmax(&(&logits / temperature)?, D::Minus1)?; let prs: Vec<f32> = prs.to_vec1()?; let distr = rand::distributions::WeightedIndex::new(prs).map_err(Error::wrap)?; |