diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-08-21 07:51:46 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-08-21 07:51:46 +0100 |
commit | 912561614f0fb0fc1e9ccff49448d3d2a85302ce (patch) | |
tree | dcc6d8a38dc2e30b2c802ce1c91c45af4ebb9d96 /candle-transformers/src/generation | |
parent | 8c232d706bde5ef9285218dd71a26f4a1cdb5550 (diff) | |
download | candle-912561614f0fb0fc1e9ccff49448d3d2a85302ce.tar.gz candle-912561614f0fb0fc1e9ccff49448d3d2a85302ce.tar.bz2 candle-912561614f0fb0fc1e9ccff49448d3d2a85302ce.zip |
Better handling of zero temperatures. (#532)
Diffstat (limited to 'candle-transformers/src/generation')
-rw-r--r-- | candle-transformers/src/generation/mod.rs | 3 |
1 files changed, 2 insertions, 1 deletions
diff --git a/candle-transformers/src/generation/mod.rs b/candle-transformers/src/generation/mod.rs index d2ac33e9..b1d20168 100644 --- a/candle-transformers/src/generation/mod.rs +++ b/candle-transformers/src/generation/mod.rs @@ -16,7 +16,8 @@ impl LogitsProcessor { pub fn sample(&mut self, logits: &Tensor) -> Result<u32> { let logits = logits.to_dtype(DType::F32)?; - let next_token = if let Some(temperature) = self.temperature { + let temperature = self.temperature.unwrap_or(0.); + let next_token = if temperature > 0. { let prs = candle_nn::ops::softmax(&(&logits / temperature)?, D::Minus1)?; let prs: Vec<f32> = prs.to_vec1()?; let distr = rand::distributions::WeightedIndex::new(prs).map_err(Error::wrap)?; |