diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-07-24 12:36:02 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-07-24 12:36:02 +0100 |
commit | 5a26cba7339e326eaca7a10ee99f6af948da2677 (patch) | |
tree | e7ce4f569f3d620bd73c0bbb00198031345723b2 /candle-wasm-examples/llama2-c/src/app.rs | |
parent | 550a13a5472fd3aa3975c2453eff4bff6ac1d0bd (diff) | |
download | candle-5a26cba7339e326eaca7a10ee99f6af948da2677.tar.gz candle-5a26cba7339e326eaca7a10ee99f6af948da2677.tar.bz2 candle-5a26cba7339e326eaca7a10ee99f6af948da2677.zip |
Re-organize the wasm examples (#231)
* Move the whisper example.
* More renaming.
* Add llama2 as a new wasm example.
* Live generation.
* More of the llama wasm example.
* Formatting.
Diffstat (limited to 'candle-wasm-examples/llama2-c/src/app.rs')
-rw-r--r-- | candle-wasm-examples/llama2-c/src/app.rs | 188 |
1 files changed, 188 insertions, 0 deletions
diff --git a/candle-wasm-examples/llama2-c/src/app.rs b/candle-wasm-examples/llama2-c/src/app.rs new file mode 100644 index 00000000..460ac053 --- /dev/null +++ b/candle-wasm-examples/llama2-c/src/app.rs @@ -0,0 +1,188 @@ +use crate::console_log; +use crate::worker::{ModelData, Worker, WorkerInput, WorkerOutput}; +use wasm_bindgen::prelude::*; +use wasm_bindgen_futures::JsFuture; +use yew::{html, Component, Context, Html}; +use yew_agent::{Bridge, Bridged}; + +async fn fetch_url(url: &str) -> Result<Vec<u8>, JsValue> { + use web_sys::{Request, RequestCache, RequestInit, RequestMode, Response}; + let window = web_sys::window().ok_or("window")?; + let mut opts = RequestInit::new(); + let opts = opts + .method("GET") + .mode(RequestMode::Cors) + .cache(RequestCache::NoCache); + + let request = Request::new_with_str_and_init(url, opts)?; + + let resp_value = JsFuture::from(window.fetch_with_request(&request)).await?; + + // `resp_value` is a `Response` object. + assert!(resp_value.is_instance_of::<Response>()); + let resp: Response = resp_value.dyn_into()?; + let data = JsFuture::from(resp.blob()?).await?; + let blob = web_sys::Blob::from(data); + let array_buffer = JsFuture::from(blob.array_buffer()).await?; + let data = js_sys::Uint8Array::new(&array_buffer).to_vec(); + Ok(data) +} + +pub enum Msg { + Run, + UpdateStatus(String), + SetModel(ModelData), + WorkerInMsg(WorkerInput), + WorkerOutMsg(Result<WorkerOutput, String>), +} + +pub struct CurrentDecode { + start_time: Option<f64>, +} + +pub struct App { + status: String, + generated: String, + current_decode: Option<CurrentDecode>, + worker: Box<dyn Bridge<Worker>>, +} + +async fn model_data_load() -> Result<ModelData, JsValue> { + let tokenizer = fetch_url("tokenizer.bin").await?; + let model = fetch_url("model.bin").await?; + console_log!("{}", model.len()); + Ok(ModelData { tokenizer, model }) +} + +fn performance_now() -> Option<f64> { + let window = web_sys::window()?; + let performance = window.performance()?; + Some(performance.now() / 1000.) +} + +impl Component for App { + type Message = Msg; + type Properties = (); + + fn create(ctx: &Context<Self>) -> Self { + let status = "loading weights".to_string(); + let cb = { + let link = ctx.link().clone(); + move |e| link.send_message(Self::Message::WorkerOutMsg(e)) + }; + let worker = Worker::bridge(std::rc::Rc::new(cb)); + Self { + status, + generated: String::new(), + current_decode: None, + worker, + } + } + + fn rendered(&mut self, ctx: &Context<Self>, first_render: bool) { + if first_render { + ctx.link().send_future(async { + match model_data_load().await { + Err(err) => { + let status = format!("{err:?}"); + Msg::UpdateStatus(status) + } + Ok(model_data) => Msg::SetModel(model_data), + } + }); + } + } + + fn update(&mut self, ctx: &Context<Self>, msg: Self::Message) -> bool { + match msg { + Msg::SetModel(md) => { + self.status = "weights loaded succesfully!".to_string(); + console_log!("loaded weights"); + self.worker.send(WorkerInput::ModelData(md)); + true + } + Msg::Run => { + if self.current_decode.is_some() { + self.status = "already generating some sample at the moment".to_string() + } else { + let start_time = performance_now(); + self.current_decode = Some(CurrentDecode { start_time }); + self.status = "generating...".to_string(); + self.generated.clear(); + ctx.link().send_message(Msg::WorkerInMsg(WorkerInput::Run)) + } + true + } + Msg::WorkerOutMsg(output) => { + match output { + Ok(WorkerOutput::WeightsLoaded) => self.status = "weights loaded!".to_string(), + Ok(WorkerOutput::GenerationDone(Err(err))) => { + self.status = format!("error in worker process: {err}"); + self.current_decode = None + } + Ok(WorkerOutput::GenerationDone(Ok(()))) => { + let dt = self.current_decode.as_ref().and_then(|current_decode| { + current_decode.start_time.and_then(|start_time| { + performance_now().map(|stop_time| stop_time - start_time) + }) + }); + self.status = match dt { + None => "generation succeeded!".to_string(), + Some(dt) => format!("generation succeeded in {:.2}s", dt), + }; + self.current_decode = None + } + Ok(WorkerOutput::Generated(token)) => self.generated.push_str(&token), + Err(err) => { + self.status = format!("error in worker {err:?}"); + } + } + true + } + Msg::WorkerInMsg(inp) => { + self.worker.send(inp); + true + } + Msg::UpdateStatus(status) => { + self.status = status; + true + } + } + } + + fn view(&self, ctx: &Context<Self>) -> Html { + html! { + <div> + <div><p>{"Running "} + <a href="https://github.com/karpathy/llama2.c" target="_blank">{"llama2.c"}</a> + {" in the browser using rust/wasm with "} + <a href="https://github.com/LaurentMazare/candle" target="_blank">{"candle!"}</a> + </p> + <p>{"Once the weights have loaded, click on the run button to start generating content."} + </p> + </div> + <button class="button" onclick={ctx.link().callback(move |_| Msg::Run)}> { "run" }</button> + <br/ > + <h3> + {&self.status} + </h3> + { + if self.current_decode.is_some() { + html! { <progress id="progress-bar" aria-label="generating…"></progress> } + } else { + html! {} + } + } + <blockquote> + <p> { self.generated.chars().map(|c| + if c == '\r' || c == '\n' { + html! { <br/> } + } else { + html! { {c} } + }).collect::<Html>() + } </p> + </blockquote> + </div> + } + } +} |