diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-07-24 15:28:27 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-07-24 15:28:27 +0100 |
commit | 160ba09d3062ede06a770ca4b8fc5c42b16a2d6a (patch) | |
tree | 0bea1eec9338fa8bbb93e010fcc1d398918124e4 /candle-wasm-examples/llama2-c | |
parent | 5a26cba7339e326eaca7a10ee99f6af948da2677 (diff) | |
download | candle-160ba09d3062ede06a770ca4b8fc5c42b16a2d6a.tar.gz candle-160ba09d3062ede06a770ca4b8fc5c42b16a2d6a.tar.bz2 candle-160ba09d3062ede06a770ca4b8fc5c42b16a2d6a.zip |
Polish the llama2 wasm ui. (#232)
* Polish the llama2 wasm ui.
* readme update.
Diffstat (limited to 'candle-wasm-examples/llama2-c')
-rw-r--r-- | candle-wasm-examples/llama2-c/Cargo.toml | 1 | ||||
-rw-r--r-- | candle-wasm-examples/llama2-c/src/app.rs | 19 | ||||
-rw-r--r-- | candle-wasm-examples/llama2-c/src/bin/app.rs | 1 | ||||
-rw-r--r-- | candle-wasm-examples/llama2-c/src/bin/worker.rs | 1 | ||||
-rw-r--r-- | candle-wasm-examples/llama2-c/src/model.rs | 2 | ||||
-rw-r--r-- | candle-wasm-examples/llama2-c/src/worker.rs | 18 |
6 files changed, 34 insertions, 8 deletions
diff --git a/candle-wasm-examples/llama2-c/Cargo.toml b/candle-wasm-examples/llama2-c/Cargo.toml index 22d9cfe8..6aae0e59 100644 --- a/candle-wasm-examples/llama2-c/Cargo.toml +++ b/candle-wasm-examples/llama2-c/Cargo.toml @@ -24,6 +24,7 @@ serde = { workspace = true } serde_json = { workspace = true } # Wasm specific crates. +console_error_panic_hook = "0.1.7" getrandom = { version = "0.2", features = ["js"] } gloo = "0.8" js-sys = "0.3.64" diff --git a/candle-wasm-examples/llama2-c/src/app.rs b/candle-wasm-examples/llama2-c/src/app.rs index 460ac053..eab0aa6e 100644 --- a/candle-wasm-examples/llama2-c/src/app.rs +++ b/candle-wasm-examples/llama2-c/src/app.rs @@ -1,5 +1,6 @@ use crate::console_log; use crate::worker::{ModelData, Worker, WorkerInput, WorkerOutput}; +use std::str::FromStr; use wasm_bindgen::prelude::*; use wasm_bindgen_futures::JsFuture; use yew::{html, Component, Context, Html}; @@ -42,6 +43,7 @@ pub struct CurrentDecode { pub struct App { status: String, + temperature: std::rc::Rc<std::cell::RefCell<f64>>, generated: String, current_decode: Option<CurrentDecode>, worker: Box<dyn Bridge<Worker>>, @@ -73,6 +75,7 @@ impl Component for App { let worker = Worker::bridge(std::rc::Rc::new(cb)); Self { status, + temperature: std::rc::Rc::new(std::cell::RefCell::new(0.)), generated: String::new(), current_decode: None, worker, @@ -109,7 +112,10 @@ impl Component for App { self.current_decode = Some(CurrentDecode { start_time }); self.status = "generating...".to_string(); self.generated.clear(); - ctx.link().send_message(Msg::WorkerInMsg(WorkerInput::Run)) + let temp = *self.temperature.borrow(); + console_log!("temp: {}", temp); + ctx.link() + .send_message(Msg::WorkerInMsg(WorkerInput::Run(temp))) } true } @@ -151,8 +157,16 @@ impl Component for App { } fn view(&self, ctx: &Context<Self>) -> Html { + use yew::TargetCast; + let temperature = self.temperature.clone(); + let oninput = move |e: yew::InputEvent| { + let input: web_sys::HtmlInputElement = e.target_unchecked_into(); + if let Ok(temp) = f64::from_str(&input.value()) { + *temperature.borrow_mut() = temp + } + }; html! { - <div> + <div style="margin: 2%;"> <div><p>{"Running "} <a href="https://github.com/karpathy/llama2.c" target="_blank">{"llama2.c"}</a> {" in the browser using rust/wasm with "} @@ -161,6 +175,7 @@ impl Component for App { <p>{"Once the weights have loaded, click on the run button to start generating content."} </p> </div> + {"temperature: "}<input type="range" min="0." max="1.2" step="0.1" value={self.temperature.borrow().to_string()} {oninput} id="temp"/> <button class="button" onclick={ctx.link().callback(move |_| Msg::Run)}> { "run" }</button> <br/ > <h3> diff --git a/candle-wasm-examples/llama2-c/src/bin/app.rs b/candle-wasm-examples/llama2-c/src/bin/app.rs index 3428f6ff..717eeafc 100644 --- a/candle-wasm-examples/llama2-c/src/bin/app.rs +++ b/candle-wasm-examples/llama2-c/src/bin/app.rs @@ -1,4 +1,5 @@ fn main() { wasm_logger::init(wasm_logger::Config::new(log::Level::Trace)); + console_error_panic_hook::set_once(); yew::Renderer::<candle_wasm_example_llama2::App>::new().render(); } diff --git a/candle-wasm-examples/llama2-c/src/bin/worker.rs b/candle-wasm-examples/llama2-c/src/bin/worker.rs index d8ca2172..accb51b7 100644 --- a/candle-wasm-examples/llama2-c/src/bin/worker.rs +++ b/candle-wasm-examples/llama2-c/src/bin/worker.rs @@ -1,4 +1,5 @@ use yew_agent::PublicWorker; fn main() { + console_error_panic_hook::set_once(); candle_wasm_example_llama2::Worker::register(); } diff --git a/candle-wasm-examples/llama2-c/src/model.rs b/candle-wasm-examples/llama2-c/src/model.rs index 13f939db..8b0b3c3e 100644 --- a/candle-wasm-examples/llama2-c/src/model.rs +++ b/candle-wasm-examples/llama2-c/src/model.rs @@ -20,7 +20,7 @@ pub struct Cache { masks: Arc<Mutex<HashMap<usize, Tensor>>>, pub use_kv_cache: bool, #[allow(clippy::type_complexity)] - kvs: Arc<Mutex<Vec<Option<(Tensor, Tensor)>>>>, + pub kvs: Arc<Mutex<Vec<Option<(Tensor, Tensor)>>>>, cos: Tensor, sin: Tensor, device: Device, diff --git a/candle-wasm-examples/llama2-c/src/worker.rs b/candle-wasm-examples/llama2-c/src/worker.rs index 9b0351d6..d64da8c6 100644 --- a/candle-wasm-examples/llama2-c/src/worker.rs +++ b/candle-wasm-examples/llama2-c/src/worker.rs @@ -107,9 +107,11 @@ impl LogitsProcessor { } impl Model { - fn run(&self, link: &WorkerLink<Worker>, id: HandlerId) -> Result<()> { + fn run(&self, link: &WorkerLink<Worker>, id: HandlerId, temp: f64) -> Result<()> { let dev = Device::Cpu; - let mut logits_processor = LogitsProcessor::new(299792458, None); + let temp = if temp <= 0. { None } else { Some(temp) }; + console_log!("{temp:?}"); + let mut logits_processor = LogitsProcessor::new(299792458, temp); let mut index_pos = 0; let mut tokens = vec![1u32]; @@ -299,7 +301,7 @@ pub struct Worker { #[derive(Serialize, Deserialize)] pub enum WorkerInput { ModelData(ModelData), - Run, + Run(f64), } #[derive(Serialize, Deserialize)] @@ -332,10 +334,16 @@ impl yew_agent::Worker for Worker { } Err(err) => Err(format!("model creation error {err:?}")), }, - WorkerInput::Run => match &self.model { + WorkerInput::Run(temp) => match &mut self.model { None => Err("model has not been set yet".to_string()), Some(model) => { - let result = model.run(&self.link, id).map_err(|e| e.to_string()); + { + let mut cache = model.cache.kvs.lock().unwrap(); + for elem in cache.iter_mut() { + *elem = None + } + } + let result = model.run(&self.link, id, temp).map_err(|e| e.to_string()); Ok(WorkerOutput::GenerationDone(result)) } }, |