summaryrefslogtreecommitdiff
path: root/candle-transformers
diff options
context:
space:
mode:
authorlaurent <laurent.mazare@gmail.com>2024-03-23 15:47:39 +0100
committerlaurent <laurent.mazare@gmail.com>2024-03-23 15:47:39 +0100
commit5e70821dd0dacc1b1e1e44d8ec03d0e4a25d41dc (patch)
tree23a4681b8ff344e142a593b2275134fb698d71c6 /candle-transformers
parenta62a97340c3f11fc7d804d8c6138e3da7e9d7648 (diff)
downloadcandle-5e70821dd0dacc1b1e1e44d8ec03d0e4a25d41dc.tar.gz
candle-5e70821dd0dacc1b1e1e44d8ec03d0e4a25d41dc.tar.bz2
candle-5e70821dd0dacc1b1e1e44d8ec03d0e4a25d41dc.zip
Allow for arbitrary temperature modifications.
Diffstat (limited to 'candle-transformers')
-rw-r--r--candle-transformers/src/generation/mod.rs8
1 files changed, 7 insertions, 1 deletions
diff --git a/candle-transformers/src/generation/mod.rs b/candle-transformers/src/generation/mod.rs
index 530a6b48..257d9171 100644
--- a/candle-transformers/src/generation/mod.rs
+++ b/candle-transformers/src/generation/mod.rs
@@ -87,11 +87,17 @@ impl LogitsProcessor {
}
pub fn sample(&mut self, logits: &Tensor) -> Result<u32> {
+ self.sample_f(logits, |_| {})
+ }
+
+ pub fn sample_f(&mut self, logits: &Tensor, f: impl FnOnce(&mut [f32])) -> Result<u32> {
let logits = logits.to_dtype(DType::F32)?;
let prs = |temperature: f64| -> Result<Vec<f32>> {
let logits = (&logits / temperature)?;
let prs = candle_nn::ops::softmax_last_dim(&logits)?;
- prs.to_vec1()
+ let mut prs = prs.to_vec1()?;
+ f(&mut prs);
+ Ok(prs)
};
let next_token = match &self.sampling {