summaryrefslogtreecommitdiff
path: root/candle-wasm-examples/llama2-c/src/bin/m.rs
diff options
context:
space:
mode:
authorJuarez Bochi <jbochi@gmail.com>2023-09-12 09:10:16 -0700
committerGitHub <noreply@github.com>2023-09-12 18:10:16 +0200
commit805bf9ffa78119a1a7e047b4ddf6b2ea7df4d94f (patch)
tree0df65e2e6fee356d2345954701ec3d47796ae7ee /candle-wasm-examples/llama2-c/src/bin/m.rs
parent42da17694a4214a3e39e0d64afc22635ce83f557 (diff)
downloadcandle-805bf9ffa78119a1a7e047b4ddf6b2ea7df4d94f.tar.gz
candle-805bf9ffa78119a1a7e047b4ddf6b2ea7df4d94f.tar.bz2
candle-805bf9ffa78119a1a7e047b4ddf6b2ea7df4d94f.zip
Implement top_p / nucleus sampling (#819)
* Implement top_p / nucleus sampling * Update changelog * rustfmt * Add tests * Fix clippy warning * Fix another clippy error
Diffstat (limited to 'candle-wasm-examples/llama2-c/src/bin/m.rs')
-rw-r--r--candle-wasm-examples/llama2-c/src/bin/m.rs10
1 files changed, 8 insertions, 2 deletions
diff --git a/candle-wasm-examples/llama2-c/src/bin/m.rs b/candle-wasm-examples/llama2-c/src/bin/m.rs
index 6628ab7e..61de9d7f 100644
--- a/candle-wasm-examples/llama2-c/src/bin/m.rs
+++ b/candle-wasm-examples/llama2-c/src/bin/m.rs
@@ -47,7 +47,7 @@ impl Model {
tokenizer,
model: weights,
});
- let logits_processor = LogitsProcessor::new(299792458, None);
+ let logits_processor = LogitsProcessor::new(299792458, None, None);
match model {
Ok(inner) => Ok(Self {
inner,
@@ -69,6 +69,7 @@ impl Model {
&mut self,
prompt: String,
temp: f64,
+ top_p: f64,
repeat_penalty: f32,
seed: u64,
) -> Result<String, JsError> {
@@ -80,7 +81,12 @@ impl Model {
}
}
let temp = if temp <= 0. { None } else { Some(temp) };
- self.logits_processor = LogitsProcessor::new(seed, temp);
+ let top_p = if top_p <= 0. || top_p >= 1. {
+ None
+ } else {
+ Some(top_p)
+ };
+ self.logits_processor = LogitsProcessor::new(seed, temp, top_p);
self.repeat_penalty = repeat_penalty;
self.tokens.clear();
let tokens = self