diff options
Diffstat (limited to 'candle-wasm-examples/llama2-c/src/worker.rs')
-rw-r--r-- | candle-wasm-examples/llama2-c/src/worker.rs | 16 |
1 files changed, 11 insertions, 5 deletions
diff --git a/candle-wasm-examples/llama2-c/src/worker.rs b/candle-wasm-examples/llama2-c/src/worker.rs index 7e97b5da..79dd2f32 100644 --- a/candle-wasm-examples/llama2-c/src/worker.rs +++ b/candle-wasm-examples/llama2-c/src/worker.rs @@ -62,12 +62,18 @@ impl Model { link: &WorkerLink<Worker>, id: HandlerId, temp: f64, + top_p: f64, prompt: String, ) -> Result<()> { let dev = Device::Cpu; let temp = if temp <= 0. { None } else { Some(temp) }; - console_log!("{temp:?} {prompt}"); - let mut logits_processor = LogitsProcessor::new(299792458, temp); + let top_p = if top_p <= 0. || top_p >= 1.0 { + None + } else { + Some(top_p) + }; + console_log!("temp: {temp:?} top_p: {top_p:?} prompt: {prompt}"); + let mut logits_processor = LogitsProcessor::new(299792458, temp, top_p); let mut index_pos = 0; let mut tokens = self .tokenizer @@ -268,7 +274,7 @@ pub struct Worker { #[derive(Serialize, Deserialize)] pub enum WorkerInput { ModelData(ModelData), - Run(f64, String), + Run(f64, f64, String), } #[derive(Serialize, Deserialize)] @@ -301,7 +307,7 @@ impl yew_agent::Worker for Worker { } Err(err) => Err(format!("model creation error {err:?}")), }, - WorkerInput::Run(temp, prompt) => match &mut self.model { + WorkerInput::Run(temp, top_p, prompt) => match &mut self.model { None => Err("model has not been set yet".to_string()), Some(model) => { { @@ -311,7 +317,7 @@ impl yew_agent::Worker for Worker { } } let result = model - .run(&self.link, id, temp, prompt) + .run(&self.link, id, temp, top_p, prompt) .map_err(|e| e.to_string()); Ok(WorkerOutput::GenerationDone(result)) } |