summaryrefslogtreecommitdiff
path: root/candle-transformers/src/generation
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-08-21 07:51:46 +0100
committerGitHub <noreply@github.com>2023-08-21 07:51:46 +0100
commit912561614f0fb0fc1e9ccff49448d3d2a85302ce (patch)
treedcc6d8a38dc2e30b2c802ce1c91c45af4ebb9d96 /candle-transformers/src/generation
parent8c232d706bde5ef9285218dd71a26f4a1cdb5550 (diff)
downloadcandle-912561614f0fb0fc1e9ccff49448d3d2a85302ce.tar.gz
candle-912561614f0fb0fc1e9ccff49448d3d2a85302ce.tar.bz2
candle-912561614f0fb0fc1e9ccff49448d3d2a85302ce.zip
Better handling of zero temperatures. (#532)
Diffstat (limited to 'candle-transformers/src/generation')
-rw-r--r--candle-transformers/src/generation/mod.rs3
1 files changed, 2 insertions, 1 deletions
diff --git a/candle-transformers/src/generation/mod.rs b/candle-transformers/src/generation/mod.rs
index d2ac33e9..b1d20168 100644
--- a/candle-transformers/src/generation/mod.rs
+++ b/candle-transformers/src/generation/mod.rs
@@ -16,7 +16,8 @@ impl LogitsProcessor {
pub fn sample(&mut self, logits: &Tensor) -> Result<u32> {
let logits = logits.to_dtype(DType::F32)?;
- let next_token = if let Some(temperature) = self.temperature {
+ let temperature = self.temperature.unwrap_or(0.);
+ let next_token = if temperature > 0. {
let prs = candle_nn::ops::softmax(&(&logits / temperature)?, D::Minus1)?;
let prs: Vec<f32> = prs.to_vec1()?;
let distr = rand::distributions::WeightedIndex::new(prs).map_err(Error::wrap)?;