summary | refs | log | tree | commit | diff
path: root/candle-wasm-examples/llama2-c
diff options
context:
space:
mode:
author Laurent Mazare <laurent.mazare@gmail.com> 2023-08-02 15:49:43 +0100
committer GitHub <noreply@github.com> 2023-08-02 15:49:43 +0100
commit 186c308d5158d04a7e0bc503567c3813d5370aad (patch)
tree 23409e37216662614b89fd6b4a4912365c0c01b6 /candle-wasm-examples/llama2-c
parent 4f17290ce05963ae3416f8224ddda77eb67be299 (diff)
download: candle-186c308d5158d04a7e0bc503567c3813d5370aad.tar.gz
candle-186c308d5158d04a7e0bc503567c3813d5370aad.tar.bz2
candle-186c308d5158d04a7e0bc503567c3813d5370aad.zip
Wasm llama2 tweaks (#309)
* Clean-up the llama2.c wasm example.
* Use a proper tokenizer.
Diffstat (limited to 'candle-wasm-examples/llama2-c')
-rw-r--r-- candle-wasm-examples/llama2-c/Cargo.toml 1
-rw-r--r-- candle-wasm-examples/llama2-c/index.html 2
-rw-r--r-- candle-wasm-examples/llama2-c/src/app.rs 2
-rw-r--r-- candle-wasm-examples/llama2-c/src/lib.rs 25
-rw-r--r-- candle-wasm-examples/llama2-c/src/model.rs 8
-rw-r--r-- candle-wasm-examples/llama2-c/src/worker.rs 28
6 files changed, 14 insertions, 52 deletions
diff --git a/candle-wasm-examples/llama2-c/Cargo.toml b/candle-wasm-examples/llama2-c/Cargo.toml
index 6a128e19..dab99aee 100644
--- a/candle-wasm-examples/llama2-c/Cargo.toml
+++ b/candle-wasm-examples/llama2-c/Cargo.toml
@@ -12,6 +12,7 @@ license.workspace = true
candle = { path = "../../candle-core", version = "0.1.0", package = "candle-core" }
candle-nn = { path = "../../candle-nn", version = "0.1.0" }
num-traits = { workspace = true }
+tokenizers = { workspace = true, features = ["unstable_wasm"] }
# App crates.
anyhow = { workspace = true }
diff --git a/candle-wasm-examples/llama2-c/index.html b/candle-wasm-examples/llama2-c/index.html
index e98e1ecb..6c8fe8ff 100644
--- a/candle-wasm-examples/llama2-c/index.html
+++ b/candle-wasm-examples/llama2-c/index.html
@@ -4,7 +4,7 @@
<meta charset="utf-8" />
<title>Welcome to Candle!</title>
- <link data-trunk rel="copy-file" href="tokenizer.bin" />
+ <link data-trunk rel="copy-file" href="tokenizer.json" />
<link data-trunk rel="copy-file" href="model.bin" />
<link data-trunk rel="rust" href="Cargo.toml" data-bin="app" data-type="main" />
<link data-trunk rel="rust" href="Cargo.toml" data-bin="worker" data-type="worker" />
diff --git a/candle-wasm-examples/llama2-c/src/app.rs b/candle-wasm-examples/llama2-c/src/app.rs
index d4ac6d07..f17cdbe3 100644
--- a/candle-wasm-examples/llama2-c/src/app.rs
+++ b/candle-wasm-examples/llama2-c/src/app.rs
@@ -53,7 +53,7 @@ pub struct App {
}
async fn model_data_load() -> Result<ModelData, JsValue> {
- let tokenizer = fetch_url("tokenizer.bin").await?;
+ let tokenizer = fetch_url("tokenizer.json").await?;
let model = fetch_url("model.bin").await?;
console_log!("{}", model.len());
Ok(ModelData { tokenizer, model })
diff --git a/candle-wasm-examples/llama2-c/src/lib.rs b/candle-wasm-examples/llama2-c/src/lib.rs
index 61154d04..b6b4004f 100644
--- a/candle-wasm-examples/llama2-c/src/lib.rs
+++ b/candle-wasm-examples/llama2-c/src/lib.rs
@@ -1,28 +1,3 @@
-#![allow(dead_code)]
-
-pub const WITH_TIMER: bool = true;
-
-struct Timer {
- label: &'static str,
-}
-
-impl Timer {
- fn new(label: &'static str) -> Self {
- if WITH_TIMER {
- web_sys::console::time_with_label(label);
- }
- Self { label }
- }
-}
-
-impl Drop for Timer {
- fn drop(&mut self) {
- if WITH_TIMER {
- web_sys::console::time_end_with_label(self.label)
- }
- }
-}
-
mod app;
mod model;
mod worker;
diff --git a/candle-wasm-examples/llama2-c/src/model.rs b/candle-wasm-examples/llama2-c/src/model.rs
index 8cf53c2a..3231cabf 100644
--- a/candle-wasm-examples/llama2-c/src/model.rs
+++ b/candle-wasm-examples/llama2-c/src/model.rs
@@ -106,14 +106,15 @@ struct CausalSelfAttention {
n_key_value_head: usize,
head_dim: usize,
cache: Cache,
- max_seq_len: usize,
}
impl CausalSelfAttention {
fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
let (b_sz, seq_len, h, n_embd) = x.dims4()?;
- let cos = self.cache.cos.narrow(0, index_pos, seq_len)?;
- let sin = self.cache.sin.narrow(0, index_pos, seq_len)?;
+ let cos = self.cache.cos.i(index_pos..index_pos + seq_len)?;
+ let sin = self.cache.sin.i(index_pos..index_pos + seq_len)?;
+ let cos = cos.unsqueeze(1)?;
+ let sin = sin.unsqueeze(1)?;
let cos = cos.broadcast_as((b_sz, seq_len, 1, n_embd / 2, 1))?;
let sin = sin.broadcast_as((b_sz, seq_len, 1, n_embd / 2, 1))?;
let x = x.reshape((b_sz, seq_len, h, n_embd / 2, 2))?;
@@ -196,7 +197,6 @@ impl CausalSelfAttention {
n_key_value_head: cfg.n_kv_heads,
head_dim: cfg.dim / cfg.n_heads,
cache: cache.clone(),
- max_seq_len: cfg.seq_len,
})
}
}
diff --git a/candle-wasm-examples/llama2-c/src/worker.rs b/candle-wasm-examples/llama2-c/src/worker.rs
index 79f7c1fd..3a43c57a 100644
--- a/candle-wasm-examples/llama2-c/src/worker.rs
+++ b/candle-wasm-examples/llama2-c/src/worker.rs
@@ -4,6 +4,7 @@ use candle::{DType, Device, IndexOp, Result, Shape, Tensor, D};
use candle_nn::{ops::softmax, VarBuilder};
use rand::{distributions::Distribution, SeedableRng};
use serde::{Deserialize, Serialize};
+use tokenizers::Tokenizer;
use wasm_bindgen::prelude::*;
use yew_agent::{HandlerId, Public, WorkerLink};
@@ -48,23 +49,6 @@ fn read_tensor<R: std::io::Read, S: Into<Shape>>(
Ok(tensor)
}
-struct Tokenizer {
- tokens: Vec<String>,
-}
-
-impl Tokenizer {
- fn from_reader<R: std::io::Read>(r: &mut R, c: &Config) -> Result<Self> {
- let mut tokens = Vec::with_capacity(c.vocab_size);
- for _token_index in 0..c.vocab_size {
- let token_len = read_i32(r)?;
- let mut token = vec![0u8; token_len as usize];
- r.read_exact(&mut token)?;
- tokens.push(String::from_utf8_lossy(&token).into_owned())
- }
- Ok(Self { tokens })
- }
-}
-
struct Model {
cache: Cache,
config: Config,
@@ -129,8 +113,10 @@ impl Model {
let next_token = logits_processor.sample(&logits)?;
tokens.push(next_token);
- let token = self.tokenizer.tokens[next_token as usize].clone();
- link.respond(id, Ok(WorkerOutput::Generated(token)));
+ if let Some(text) = self.tokenizer.id_to_token(next_token) {
+ let text = text.replace('▁', " ").replace("<0x0A>", "\n");
+ link.respond(id, Ok(WorkerOutput::Generated(text)));
+ }
}
Ok(())
}
@@ -282,8 +268,8 @@ impl Model {
let vb = weights.var_builder(&config, &dev)?;
let cache = Cache::new(true, &config, vb.pp("rot"))?;
let llama = Llama::load(vb, &cache, &config)?;
- let mut tokenizer = std::io::Cursor::new(md.tokenizer);
- let tokenizer = Tokenizer::from_reader(&mut tokenizer, &config)?;
+ let tokenizer =
+ Tokenizer::from_bytes(&md.tokenizer).map_err(|m| candle::Error::Msg(m.to_string()))?;
Ok(Self {
cache,
config,