diff options
author | Juarez Bochi <jbochi@gmail.com> | 2023-09-12 09:10:16 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-09-12 18:10:16 +0200 |
commit | 805bf9ffa78119a1a7e047b4ddf6b2ea7df4d94f (patch) | |
tree | 0df65e2e6fee356d2345954701ec3d47796ae7ee /candle-wasm-examples/llama2-c/src/bin/m.rs | |
parent | 42da17694a4214a3e39e0d64afc22635ce83f557 (diff) | |
download | candle-805bf9ffa78119a1a7e047b4ddf6b2ea7df4d94f.tar.gz candle-805bf9ffa78119a1a7e047b4ddf6b2ea7df4d94f.tar.bz2 candle-805bf9ffa78119a1a7e047b4ddf6b2ea7df4d94f.zip |
Implement top_p / nucleus sampling (#819)
* Implement top_p / nucleus sampling
* Update changelog
* rustfmt
* Add tests
* Fix clippy warning
* Fix another clippy error
Diffstat (limited to 'candle-wasm-examples/llama2-c/src/bin/m.rs')
-rw-r--r-- | candle-wasm-examples/llama2-c/src/bin/m.rs | 10 |
1 files changed, 8 insertions, 2 deletions
diff --git a/candle-wasm-examples/llama2-c/src/bin/m.rs b/candle-wasm-examples/llama2-c/src/bin/m.rs index 6628ab7e..61de9d7f 100644 --- a/candle-wasm-examples/llama2-c/src/bin/m.rs +++ b/candle-wasm-examples/llama2-c/src/bin/m.rs @@ -47,7 +47,7 @@ impl Model { tokenizer, model: weights, }); - let logits_processor = LogitsProcessor::new(299792458, None); + let logits_processor = LogitsProcessor::new(299792458, None, None); match model { Ok(inner) => Ok(Self { inner, @@ -69,6 +69,7 @@ impl Model { &mut self, prompt: String, temp: f64, + top_p: f64, repeat_penalty: f32, seed: u64, ) -> Result<String, JsError> { @@ -80,7 +81,12 @@ impl Model { } } let temp = if temp <= 0. { None } else { Some(temp) }; - self.logits_processor = LogitsProcessor::new(seed, temp); + let top_p = if top_p <= 0. || top_p >= 1. { + None + } else { + Some(top_p) + }; + self.logits_processor = LogitsProcessor::new(seed, temp, top_p); self.repeat_penalty = repeat_penalty; self.tokens.clear(); let tokens = self |