summaryrefslogtreecommitdiff
path: root/candle-wasm-examples/llama2-c/src/worker.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-wasm-examples/llama2-c/src/worker.rs')
-rw-r--r--candle-wasm-examples/llama2-c/src/worker.rs16
1 files changed, 11 insertions, 5 deletions
diff --git a/candle-wasm-examples/llama2-c/src/worker.rs b/candle-wasm-examples/llama2-c/src/worker.rs
index 7e97b5da..79dd2f32 100644
--- a/candle-wasm-examples/llama2-c/src/worker.rs
+++ b/candle-wasm-examples/llama2-c/src/worker.rs
@@ -62,12 +62,18 @@ impl Model {
link: &WorkerLink<Worker>,
id: HandlerId,
temp: f64,
+ top_p: f64,
prompt: String,
) -> Result<()> {
let dev = Device::Cpu;
let temp = if temp <= 0. { None } else { Some(temp) };
- console_log!("{temp:?} {prompt}");
- let mut logits_processor = LogitsProcessor::new(299792458, temp);
+ let top_p = if top_p <= 0. || top_p >= 1.0 {
+ None
+ } else {
+ Some(top_p)
+ };
+ console_log!("temp: {temp:?} top_p: {top_p:?} prompt: {prompt}");
+ let mut logits_processor = LogitsProcessor::new(299792458, temp, top_p);
let mut index_pos = 0;
let mut tokens = self
.tokenizer
@@ -268,7 +274,7 @@ pub struct Worker {
#[derive(Serialize, Deserialize)]
pub enum WorkerInput {
ModelData(ModelData),
- Run(f64, String),
+ Run(f64, f64, String),
}
#[derive(Serialize, Deserialize)]
@@ -301,7 +307,7 @@ impl yew_agent::Worker for Worker {
}
Err(err) => Err(format!("model creation error {err:?}")),
},
- WorkerInput::Run(temp, prompt) => match &mut self.model {
+ WorkerInput::Run(temp, top_p, prompt) => match &mut self.model {
None => Err("model has not been set yet".to_string()),
Some(model) => {
{
@@ -311,7 +317,7 @@ impl yew_agent::Worker for Worker {
}
}
let result = model
- .run(&self.link, id, temp, prompt)
+ .run(&self.link, id, temp, top_p, prompt)
.map_err(|e| e.to_string());
Ok(WorkerOutput::GenerationDone(result))
}