diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-08-02 17:32:36 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-08-02 17:32:36 +0100 |
commit | 52414ba5c853a2b39b393677a89d07a73fdc7a15 (patch) | |
tree | 7ad2a3d9b65c72929b8f55e6fafbbc73fd31821d /candle-wasm-examples/llama2-c | |
parent | 186c308d5158d04a7e0bc503567c3813d5370aad (diff) | |
download | candle-52414ba5c853a2b39b393677a89d07a73fdc7a15.tar.gz candle-52414ba5c853a2b39b393677a89d07a73fdc7a15.tar.bz2 candle-52414ba5c853a2b39b393677a89d07a73fdc7a15.zip |
Bugfix for the llama2 wasm example. (#310)
* Clean-up the llama2.c wasm example.
* Use a proper tokenizer.
* Add a prompt.
* Bugfix for the llama2 wasm example.
Diffstat (limited to 'candle-wasm-examples/llama2-c')
-rw-r--r-- | candle-wasm-examples/llama2-c/src/app.rs | 15 | ||||
-rw-r--r-- | candle-wasm-examples/llama2-c/src/worker.rs | 31 |
2 files changed, 37 insertions, 9 deletions
diff --git a/candle-wasm-examples/llama2-c/src/app.rs b/candle-wasm-examples/llama2-c/src/app.rs index f17cdbe3..782026a4 100644 --- a/candle-wasm-examples/llama2-c/src/app.rs +++ b/candle-wasm-examples/llama2-c/src/app.rs @@ -46,6 +46,7 @@ pub struct App { status: String, loaded: bool, temperature: std::rc::Rc<std::cell::RefCell<f64>>, + prompt: std::rc::Rc<std::cell::RefCell<String>>, generated: String, n_tokens: usize, current_decode: Option<CurrentDecode>, @@ -80,6 +81,7 @@ impl Component for App { status, n_tokens: 0, temperature: std::rc::Rc::new(std::cell::RefCell::new(0.)), + prompt: std::rc::Rc::new(std::cell::RefCell::new("".to_string())), generated: String::new(), current_decode: None, worker, @@ -120,9 +122,10 @@ impl Component for App { self.n_tokens = 0; self.generated.clear(); let temp = *self.temperature.borrow(); - console_log!("temp: {}", temp); + let prompt = self.prompt.borrow().clone(); + console_log!("temp: {}, prompt: {}", temp, prompt); ctx.link() - .send_message(Msg::WorkerInMsg(WorkerInput::Run(temp))) + .send_message(Msg::WorkerInMsg(WorkerInput::Run(temp, prompt))) } true } @@ -181,6 +184,12 @@ impl Component for App { } Msg::Refresh }); + let prompt = self.prompt.clone(); + let oninput_prompt = ctx.link().callback(move |e: yew::InputEvent| { + let input: web_sys::HtmlInputElement = e.target_unchecked_into(); + *prompt.borrow_mut() = input.value(); + Msg::Refresh + }); html! { <div style="margin: 2%;"> <div><p>{"Running "} @@ -195,6 +204,8 @@ impl Component for App { <input type="range" min="0." max="1.2" step="0.1" value={self.temperature.borrow().to_string()} {oninput} id="temp"/> {format!(" \u{00a0} {}", self.temperature.borrow())} <br/ > + {"prompt: "}<input type="text" value={self.prompt.borrow().to_string()} oninput={oninput_prompt} id="prompt"/> + <br/ > { if self.loaded{ html!(<button class="button" onclick={ctx.link().callback(move |_| Msg::Run)}> { "run" }</button>) diff --git a/candle-wasm-examples/llama2-c/src/worker.rs b/candle-wasm-examples/llama2-c/src/worker.rs index 3a43c57a..0ee199af 100644 --- a/candle-wasm-examples/llama2-c/src/worker.rs +++ b/candle-wasm-examples/llama2-c/src/worker.rs @@ -91,15 +91,30 @@ impl LogitsProcessor { } impl Model { - fn run(&self, link: &WorkerLink<Worker>, id: HandlerId, temp: f64) -> Result<()> { + fn run( + &self, + link: &WorkerLink<Worker>, + id: HandlerId, + temp: f64, + prompt: String, + ) -> Result<()> { let dev = Device::Cpu; let temp = if temp <= 0. { None } else { Some(temp) }; - console_log!("{temp:?}"); + console_log!("{temp:?} {prompt}"); let mut logits_processor = LogitsProcessor::new(299792458, temp); let mut index_pos = 0; - let mut tokens = vec![1u32]; + let mut tokens = self + .tokenizer + .encode(prompt.to_string(), true) + .map_err(|m| candle::Error::Msg(m.to_string()))? + .get_ids() + .to_vec(); + link.respond(id, Ok(WorkerOutput::Generated(prompt))); - for index in 0..self.config.seq_len - 10 { + for index in 0.. { + if tokens.len() >= self.config.seq_len { + break; + } let context_size = if self.cache.use_kv_cache && index > 0 { 1 } else { @@ -287,7 +302,7 @@ pub struct Worker { #[derive(Serialize, Deserialize)] pub enum WorkerInput { ModelData(ModelData), - Run(f64), + Run(f64, String), } #[derive(Serialize, Deserialize)] @@ -320,7 +335,7 @@ impl yew_agent::Worker for Worker { } Err(err) => Err(format!("model creation error {err:?}")), }, - WorkerInput::Run(temp) => match &mut self.model { + WorkerInput::Run(temp, prompt) => match &mut self.model { None => Err("model has not been set yet".to_string()), Some(model) => { { @@ -329,7 +344,9 @@ impl yew_agent::Worker for Worker { *elem = None } } - let result = model.run(&self.link, id, temp).map_err(|e| e.to_string()); + let result = model + .run(&self.link, id, temp, prompt) + .map_err(|e| e.to_string()); Ok(WorkerOutput::GenerationDone(result)) } }, |