From 805bf9ffa78119a1a7e047b4ddf6b2ea7df4d94f Mon Sep 17 00:00:00 2001 From: Juarez Bochi Date: Tue, 12 Sep 2023 09:10:16 -0700 Subject: Implement top_p / nucleus sampling (#819) * Implement top_p / nucleus sampling * Update changelog * rustfmt * Add tests * Fix clippy warning * Fix another clippy error --- candle-wasm-examples/llama2-c/src/bin/m.rs | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) (limited to 'candle-wasm-examples/llama2-c/src/bin/m.rs') diff --git a/candle-wasm-examples/llama2-c/src/bin/m.rs b/candle-wasm-examples/llama2-c/src/bin/m.rs index 6628ab7e..61de9d7f 100644 --- a/candle-wasm-examples/llama2-c/src/bin/m.rs +++ b/candle-wasm-examples/llama2-c/src/bin/m.rs @@ -47,7 +47,7 @@ impl Model { tokenizer, model: weights, }); - let logits_processor = LogitsProcessor::new(299792458, None); + let logits_processor = LogitsProcessor::new(299792458, None, None); match model { Ok(inner) => Ok(Self { inner, @@ -69,6 +69,7 @@ impl Model { &mut self, prompt: String, temp: f64, + top_p: f64, repeat_penalty: f32, seed: u64, ) -> Result { @@ -80,7 +81,12 @@ impl Model { } } let temp = if temp <= 0. { None } else { Some(temp) }; - self.logits_processor = LogitsProcessor::new(seed, temp); + let top_p = if top_p <= 0. || top_p >= 1. { + None + } else { + Some(top_p) + }; + self.logits_processor = LogitsProcessor::new(seed, temp, top_p); self.repeat_penalty = repeat_penalty; self.tokens.clear(); let tokens = self -- cgit v1.2.3