diff options
author | laurent <laurent.mazare@gmail.com> | 2024-03-23 15:47:39 +0100 |
---|---|---|
committer | laurent <laurent.mazare@gmail.com> | 2024-03-23 15:47:39 +0100 |
commit | 5e70821dd0dacc1b1e1e44d8ec03d0e4a25d41dc (patch) | |
tree | 23a4681b8ff344e142a593b2275134fb698d71c6 /candle-transformers | |
parent | a62a97340c3f11fc7d804d8c6138e3da7e9d7648 (diff) | |
download | candle-5e70821dd0dacc1b1e1e44d8ec03d0e4a25d41dc.tar.gz candle-5e70821dd0dacc1b1e1e44d8ec03d0e4a25d41dc.tar.bz2 candle-5e70821dd0dacc1b1e1e44d8ec03d0e4a25d41dc.zip |
Allow for arbitrary temperature modifications.
Diffstat (limited to 'candle-transformers')
-rw-r--r-- | candle-transformers/src/generation/mod.rs | 8 |
1 files changed, 7 insertions, 1 deletions
diff --git a/candle-transformers/src/generation/mod.rs b/candle-transformers/src/generation/mod.rs index 530a6b48..257d9171 100644 --- a/candle-transformers/src/generation/mod.rs +++ b/candle-transformers/src/generation/mod.rs @@ -87,11 +87,17 @@ impl LogitsProcessor { } pub fn sample(&mut self, logits: &Tensor) -> Result<u32> { + self.sample_f(logits, |_| {}) + } + + pub fn sample_f(&mut self, logits: &Tensor, f: impl FnOnce(&mut [f32])) -> Result<u32> { let logits = logits.to_dtype(DType::F32)?; let prs = |temperature: f64| -> Result<Vec<f32>> { let logits = (&logits / temperature)?; let prs = candle_nn::ops::softmax_last_dim(&logits)?; - prs.to_vec1() + let mut prs = prs.to_vec1()?; + f(&mut prs); + Ok(prs) }; let next_token = match &self.sampling { |