summaryrefslogtreecommitdiff
path: root/candle-wasm-examples/llama2-c
diff options
context:
space:
mode:
author    Laurent Mazare <laurent.mazare@gmail.com> 2023-08-02 17:32:36 +0100
committer GitHub <noreply@github.com> 2023-08-02 17:32:36 +0100
commit    52414ba5c853a2b39b393677a89d07a73fdc7a15 (patch)
tree      7ad2a3d9b65c72929b8f55e6fafbbc73fd31821d /candle-wasm-examples/llama2-c
parent    186c308d5158d04a7e0bc503567c3813d5370aad (diff)
download  candle-52414ba5c853a2b39b393677a89d07a73fdc7a15.tar.gz
          candle-52414ba5c853a2b39b393677a89d07a73fdc7a15.tar.bz2
          candle-52414ba5c853a2b39b393677a89d07a73fdc7a15.zip
Bugfix for the llama2 wasm example. (#310)
* Clean-up the llama2.c wasm example. * Use a proper tokenizer. * Add a prompt. * Bugfix for the llama2 wasm example.
Diffstat (limited to 'candle-wasm-examples/llama2-c')
-rw-r--r--  candle-wasm-examples/llama2-c/src/app.rs     15
-rw-r--r--  candle-wasm-examples/llama2-c/src/worker.rs  31
2 files changed, 37 insertions(+), 9 deletions(-)
diff --git a/candle-wasm-examples/llama2-c/src/app.rs b/candle-wasm-examples/llama2-c/src/app.rs
index f17cdbe3..782026a4 100644
--- a/candle-wasm-examples/llama2-c/src/app.rs
+++ b/candle-wasm-examples/llama2-c/src/app.rs
@@ -46,6 +46,7 @@ pub struct App {
status: String,
loaded: bool,
temperature: std::rc::Rc<std::cell::RefCell<f64>>,
+ prompt: std::rc::Rc<std::cell::RefCell<String>>,
generated: String,
n_tokens: usize,
current_decode: Option<CurrentDecode>,
@@ -80,6 +81,7 @@ impl Component for App {
status,
n_tokens: 0,
temperature: std::rc::Rc::new(std::cell::RefCell::new(0.)),
+ prompt: std::rc::Rc::new(std::cell::RefCell::new("".to_string())),
generated: String::new(),
current_decode: None,
worker,
@@ -120,9 +122,10 @@ impl Component for App {
self.n_tokens = 0;
self.generated.clear();
let temp = *self.temperature.borrow();
- console_log!("temp: {}", temp);
+ let prompt = self.prompt.borrow().clone();
+ console_log!("temp: {}, prompt: {}", temp, prompt);
ctx.link()
- .send_message(Msg::WorkerInMsg(WorkerInput::Run(temp)))
+ .send_message(Msg::WorkerInMsg(WorkerInput::Run(temp, prompt)))
}
true
}
@@ -181,6 +184,12 @@ impl Component for App {
}
Msg::Refresh
});
+ let prompt = self.prompt.clone();
+ let oninput_prompt = ctx.link().callback(move |e: yew::InputEvent| {
+ let input: web_sys::HtmlInputElement = e.target_unchecked_into();
+ *prompt.borrow_mut() = input.value();
+ Msg::Refresh
+ });
html! {
<div style="margin: 2%;">
<div><p>{"Running "}
@@ -195,6 +204,8 @@ impl Component for App {
<input type="range" min="0." max="1.2" step="0.1" value={self.temperature.borrow().to_string()} {oninput} id="temp"/>
{format!(" \u{00a0} {}", self.temperature.borrow())}
<br/ >
+ {"prompt: "}<input type="text" value={self.prompt.borrow().to_string()} oninput={oninput_prompt} id="prompt"/>
+ <br/ >
{
if self.loaded{
html!(<button class="button" onclick={ctx.link().callback(move |_| Msg::Run)}> { "run" }</button>)
diff --git a/candle-wasm-examples/llama2-c/src/worker.rs b/candle-wasm-examples/llama2-c/src/worker.rs
index 3a43c57a..0ee199af 100644
--- a/candle-wasm-examples/llama2-c/src/worker.rs
+++ b/candle-wasm-examples/llama2-c/src/worker.rs
@@ -91,15 +91,30 @@ impl LogitsProcessor {
}
impl Model {
- fn run(&self, link: &WorkerLink<Worker>, id: HandlerId, temp: f64) -> Result<()> {
+ fn run(
+ &self,
+ link: &WorkerLink<Worker>,
+ id: HandlerId,
+ temp: f64,
+ prompt: String,
+ ) -> Result<()> {
let dev = Device::Cpu;
let temp = if temp <= 0. { None } else { Some(temp) };
- console_log!("{temp:?}");
+ console_log!("{temp:?} {prompt}");
let mut logits_processor = LogitsProcessor::new(299792458, temp);
let mut index_pos = 0;
- let mut tokens = vec![1u32];
+ let mut tokens = self
+ .tokenizer
+ .encode(prompt.to_string(), true)
+ .map_err(|m| candle::Error::Msg(m.to_string()))?
+ .get_ids()
+ .to_vec();
+ link.respond(id, Ok(WorkerOutput::Generated(prompt)));
- for index in 0..self.config.seq_len - 10 {
+ for index in 0.. {
+ if tokens.len() >= self.config.seq_len {
+ break;
+ }
let context_size = if self.cache.use_kv_cache && index > 0 {
1
} else {
@@ -287,7 +302,7 @@ pub struct Worker {
#[derive(Serialize, Deserialize)]
pub enum WorkerInput {
ModelData(ModelData),
- Run(f64),
+ Run(f64, String),
}
#[derive(Serialize, Deserialize)]
@@ -320,7 +335,7 @@ impl yew_agent::Worker for Worker {
}
Err(err) => Err(format!("model creation error {err:?}")),
},
- WorkerInput::Run(temp) => match &mut self.model {
+ WorkerInput::Run(temp, prompt) => match &mut self.model {
None => Err("model has not been set yet".to_string()),
Some(model) => {
{
@@ -329,7 +344,9 @@ impl yew_agent::Worker for Worker {
*elem = None
}
}
- let result = model.run(&self.link, id, temp).map_err(|e| e.to_string());
+ let result = model
+ .run(&self.link, id, temp, prompt)
+ .map_err(|e| e.to_string());
Ok(WorkerOutput::GenerationDone(result))
}
},