summaryrefslogtreecommitdiff
path: root/candle-wasm-examples/llama2-c
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-07-24 15:28:27 +0100
committerGitHub <noreply@github.com>2023-07-24 15:28:27 +0100
commit160ba09d3062ede06a770ca4b8fc5c42b16a2d6a (patch)
tree0bea1eec9338fa8bbb93e010fcc1d398918124e4 /candle-wasm-examples/llama2-c
parent5a26cba7339e326eaca7a10ee99f6af948da2677 (diff)
downloadcandle-160ba09d3062ede06a770ca4b8fc5c42b16a2d6a.tar.gz
candle-160ba09d3062ede06a770ca4b8fc5c42b16a2d6a.tar.bz2
candle-160ba09d3062ede06a770ca4b8fc5c42b16a2d6a.zip
Polish the llama2 wasm ui. (#232)
* Polish the llama2 wasm ui. * readme update.
Diffstat (limited to 'candle-wasm-examples/llama2-c')
-rw-r--r--candle-wasm-examples/llama2-c/Cargo.toml1
-rw-r--r--candle-wasm-examples/llama2-c/src/app.rs19
-rw-r--r--candle-wasm-examples/llama2-c/src/bin/app.rs1
-rw-r--r--candle-wasm-examples/llama2-c/src/bin/worker.rs1
-rw-r--r--candle-wasm-examples/llama2-c/src/model.rs2
-rw-r--r--candle-wasm-examples/llama2-c/src/worker.rs18
6 files changed, 34 insertions, 8 deletions
diff --git a/candle-wasm-examples/llama2-c/Cargo.toml b/candle-wasm-examples/llama2-c/Cargo.toml
index 22d9cfe8..6aae0e59 100644
--- a/candle-wasm-examples/llama2-c/Cargo.toml
+++ b/candle-wasm-examples/llama2-c/Cargo.toml
@@ -24,6 +24,7 @@ serde = { workspace = true }
serde_json = { workspace = true }
# Wasm specific crates.
+console_error_panic_hook = "0.1.7"
getrandom = { version = "0.2", features = ["js"] }
gloo = "0.8"
js-sys = "0.3.64"
diff --git a/candle-wasm-examples/llama2-c/src/app.rs b/candle-wasm-examples/llama2-c/src/app.rs
index 460ac053..eab0aa6e 100644
--- a/candle-wasm-examples/llama2-c/src/app.rs
+++ b/candle-wasm-examples/llama2-c/src/app.rs
@@ -1,5 +1,6 @@
use crate::console_log;
use crate::worker::{ModelData, Worker, WorkerInput, WorkerOutput};
+use std::str::FromStr;
use wasm_bindgen::prelude::*;
use wasm_bindgen_futures::JsFuture;
use yew::{html, Component, Context, Html};
@@ -42,6 +43,7 @@ pub struct CurrentDecode {
pub struct App {
status: String,
+ temperature: std::rc::Rc<std::cell::RefCell<f64>>,
generated: String,
current_decode: Option<CurrentDecode>,
worker: Box<dyn Bridge<Worker>>,
@@ -73,6 +75,7 @@ impl Component for App {
let worker = Worker::bridge(std::rc::Rc::new(cb));
Self {
status,
+ temperature: std::rc::Rc::new(std::cell::RefCell::new(0.)),
generated: String::new(),
current_decode: None,
worker,
@@ -109,7 +112,10 @@ impl Component for App {
self.current_decode = Some(CurrentDecode { start_time });
self.status = "generating...".to_string();
self.generated.clear();
- ctx.link().send_message(Msg::WorkerInMsg(WorkerInput::Run))
+ let temp = *self.temperature.borrow();
+ console_log!("temp: {}", temp);
+ ctx.link()
+ .send_message(Msg::WorkerInMsg(WorkerInput::Run(temp)))
}
true
}
@@ -151,8 +157,16 @@ impl Component for App {
}
fn view(&self, ctx: &Context<Self>) -> Html {
+ use yew::TargetCast;
+ let temperature = self.temperature.clone();
+ let oninput = move |e: yew::InputEvent| {
+ let input: web_sys::HtmlInputElement = e.target_unchecked_into();
+ if let Ok(temp) = f64::from_str(&input.value()) {
+ *temperature.borrow_mut() = temp
+ }
+ };
html! {
- <div>
+ <div style="margin: 2%;">
<div><p>{"Running "}
<a href="https://github.com/karpathy/llama2.c" target="_blank">{"llama2.c"}</a>
{" in the browser using rust/wasm with "}
@@ -161,6 +175,7 @@ impl Component for App {
<p>{"Once the weights have loaded, click on the run button to start generating content."}
</p>
</div>
+ {"temperature: "}<input type="range" min="0." max="1.2" step="0.1" value={self.temperature.borrow().to_string()} {oninput} id="temp"/>
<button class="button" onclick={ctx.link().callback(move |_| Msg::Run)}> { "run" }</button>
<br/ >
<h3>
diff --git a/candle-wasm-examples/llama2-c/src/bin/app.rs b/candle-wasm-examples/llama2-c/src/bin/app.rs
index 3428f6ff..717eeafc 100644
--- a/candle-wasm-examples/llama2-c/src/bin/app.rs
+++ b/candle-wasm-examples/llama2-c/src/bin/app.rs
@@ -1,4 +1,5 @@
fn main() {
wasm_logger::init(wasm_logger::Config::new(log::Level::Trace));
+ console_error_panic_hook::set_once();
yew::Renderer::<candle_wasm_example_llama2::App>::new().render();
}
diff --git a/candle-wasm-examples/llama2-c/src/bin/worker.rs b/candle-wasm-examples/llama2-c/src/bin/worker.rs
index d8ca2172..accb51b7 100644
--- a/candle-wasm-examples/llama2-c/src/bin/worker.rs
+++ b/candle-wasm-examples/llama2-c/src/bin/worker.rs
@@ -1,4 +1,5 @@
use yew_agent::PublicWorker;
fn main() {
+ console_error_panic_hook::set_once();
candle_wasm_example_llama2::Worker::register();
}
diff --git a/candle-wasm-examples/llama2-c/src/model.rs b/candle-wasm-examples/llama2-c/src/model.rs
index 13f939db..8b0b3c3e 100644
--- a/candle-wasm-examples/llama2-c/src/model.rs
+++ b/candle-wasm-examples/llama2-c/src/model.rs
@@ -20,7 +20,7 @@ pub struct Cache {
masks: Arc<Mutex<HashMap<usize, Tensor>>>,
pub use_kv_cache: bool,
#[allow(clippy::type_complexity)]
- kvs: Arc<Mutex<Vec<Option<(Tensor, Tensor)>>>>,
+ pub kvs: Arc<Mutex<Vec<Option<(Tensor, Tensor)>>>>,
cos: Tensor,
sin: Tensor,
device: Device,
diff --git a/candle-wasm-examples/llama2-c/src/worker.rs b/candle-wasm-examples/llama2-c/src/worker.rs
index 9b0351d6..d64da8c6 100644
--- a/candle-wasm-examples/llama2-c/src/worker.rs
+++ b/candle-wasm-examples/llama2-c/src/worker.rs
@@ -107,9 +107,11 @@ impl LogitsProcessor {
}
impl Model {
- fn run(&self, link: &WorkerLink<Worker>, id: HandlerId) -> Result<()> {
+ fn run(&self, link: &WorkerLink<Worker>, id: HandlerId, temp: f64) -> Result<()> {
let dev = Device::Cpu;
- let mut logits_processor = LogitsProcessor::new(299792458, None);
+ let temp = if temp <= 0. { None } else { Some(temp) };
+ console_log!("{temp:?}");
+ let mut logits_processor = LogitsProcessor::new(299792458, temp);
let mut index_pos = 0;
let mut tokens = vec![1u32];
@@ -299,7 +301,7 @@ pub struct Worker {
#[derive(Serialize, Deserialize)]
pub enum WorkerInput {
ModelData(ModelData),
- Run,
+ Run(f64),
}
#[derive(Serialize, Deserialize)]
@@ -332,10 +334,16 @@ impl yew_agent::Worker for Worker {
}
Err(err) => Err(format!("model creation error {err:?}")),
},
- WorkerInput::Run => match &self.model {
+ WorkerInput::Run(temp) => match &mut self.model {
None => Err("model has not been set yet".to_string()),
Some(model) => {
- let result = model.run(&self.link, id).map_err(|e| e.to_string());
+ {
+ let mut cache = model.cache.kvs.lock().unwrap();
+ for elem in cache.iter_mut() {
+ *elem = None
+ }
+ }
+ let result = model.run(&self.link, id, temp).map_err(|e| e.to_string());
Ok(WorkerOutput::GenerationDone(result))
}
},