diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-08-02 15:49:43 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-08-02 15:49:43 +0100 |
commit | 186c308d5158d04a7e0bc503567c3813d5370aad (patch) | |
tree | 23409e37216662614b89fd6b4a4912365c0c01b6 /candle-wasm-examples/llama2-c | |
parent | 4f17290ce05963ae3416f8224ddda77eb67be299 (diff) | |
download | candle-186c308d5158d04a7e0bc503567c3813d5370aad.tar.gz candle-186c308d5158d04a7e0bc503567c3813d5370aad.tar.bz2 candle-186c308d5158d04a7e0bc503567c3813d5370aad.zip |
Wasm llama2 tweaks (#309)
* Clean-up the llama2.c wasm example.
* Use a proper tokenizer.
Diffstat (limited to 'candle-wasm-examples/llama2-c')
-rw-r--r-- | candle-wasm-examples/llama2-c/Cargo.toml | 1 | ||||
-rw-r--r-- | candle-wasm-examples/llama2-c/index.html | 2 | ||||
-rw-r--r-- | candle-wasm-examples/llama2-c/src/app.rs | 2 | ||||
-rw-r--r-- | candle-wasm-examples/llama2-c/src/lib.rs | 25 | ||||
-rw-r--r-- | candle-wasm-examples/llama2-c/src/model.rs | 8 | ||||
-rw-r--r-- | candle-wasm-examples/llama2-c/src/worker.rs | 28 |
6 files changed, 14 insertions, 52 deletions
diff --git a/candle-wasm-examples/llama2-c/Cargo.toml b/candle-wasm-examples/llama2-c/Cargo.toml index 6a128e19..dab99aee 100644 --- a/candle-wasm-examples/llama2-c/Cargo.toml +++ b/candle-wasm-examples/llama2-c/Cargo.toml @@ -12,6 +12,7 @@ license.workspace = true candle = { path = "../../candle-core", version = "0.1.0", package = "candle-core" } candle-nn = { path = "../../candle-nn", version = "0.1.0" } num-traits = { workspace = true } +tokenizers = { workspace = true, features = ["unstable_wasm"] } # App crates. anyhow = { workspace = true } diff --git a/candle-wasm-examples/llama2-c/index.html b/candle-wasm-examples/llama2-c/index.html index e98e1ecb..6c8fe8ff 100644 --- a/candle-wasm-examples/llama2-c/index.html +++ b/candle-wasm-examples/llama2-c/index.html @@ -4,7 +4,7 @@ <meta charset="utf-8" /> <title>Welcome to Candle!</title> - <link data-trunk rel="copy-file" href="tokenizer.bin" /> + <link data-trunk rel="copy-file" href="tokenizer.json" /> <link data-trunk rel="copy-file" href="model.bin" /> <link data-trunk rel="rust" href="Cargo.toml" data-bin="app" data-type="main" /> <link data-trunk rel="rust" href="Cargo.toml" data-bin="worker" data-type="worker" /> diff --git a/candle-wasm-examples/llama2-c/src/app.rs b/candle-wasm-examples/llama2-c/src/app.rs index d4ac6d07..f17cdbe3 100644 --- a/candle-wasm-examples/llama2-c/src/app.rs +++ b/candle-wasm-examples/llama2-c/src/app.rs @@ -53,7 +53,7 @@ pub struct App { } async fn model_data_load() -> Result<ModelData, JsValue> { - let tokenizer = fetch_url("tokenizer.bin").await?; + let tokenizer = fetch_url("tokenizer.json").await?; let model = fetch_url("model.bin").await?; console_log!("{}", model.len()); Ok(ModelData { tokenizer, model }) diff --git a/candle-wasm-examples/llama2-c/src/lib.rs b/candle-wasm-examples/llama2-c/src/lib.rs index 61154d04..b6b4004f 100644 --- a/candle-wasm-examples/llama2-c/src/lib.rs +++ b/candle-wasm-examples/llama2-c/src/lib.rs @@ -1,28 +1,3 @@ -#![allow(dead_code)] - -pub const WITH_TIMER: bool = true; - -struct Timer { - label: &'static str, -} - -impl Timer { - fn new(label: &'static str) -> Self { - if WITH_TIMER { - web_sys::console::time_with_label(label); - } - Self { label } - } -} - -impl Drop for Timer { - fn drop(&mut self) { - if WITH_TIMER { - web_sys::console::time_end_with_label(self.label) - } - } -} - mod app; mod model; mod worker; diff --git a/candle-wasm-examples/llama2-c/src/model.rs b/candle-wasm-examples/llama2-c/src/model.rs index 8cf53c2a..3231cabf 100644 --- a/candle-wasm-examples/llama2-c/src/model.rs +++ b/candle-wasm-examples/llama2-c/src/model.rs @@ -106,14 +106,15 @@ struct CausalSelfAttention { n_key_value_head: usize, head_dim: usize, cache: Cache, - max_seq_len: usize, } impl CausalSelfAttention { fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> { let (b_sz, seq_len, h, n_embd) = x.dims4()?; - let cos = self.cache.cos.narrow(0, index_pos, seq_len)?; - let sin = self.cache.sin.narrow(0, index_pos, seq_len)?; + let cos = self.cache.cos.i(index_pos..index_pos + seq_len)?; + let sin = self.cache.sin.i(index_pos..index_pos + seq_len)?; + let cos = cos.unsqueeze(1)?; + let sin = sin.unsqueeze(1)?; let cos = cos.broadcast_as((b_sz, seq_len, 1, n_embd / 2, 1))?; let sin = sin.broadcast_as((b_sz, seq_len, 1, n_embd / 2, 1))?; let x = x.reshape((b_sz, seq_len, h, n_embd / 2, 2))?; @@ -196,7 +197,6 @@ impl CausalSelfAttention { n_key_value_head: cfg.n_kv_heads, head_dim: cfg.dim / cfg.n_heads, cache: cache.clone(), - max_seq_len: cfg.seq_len, }) } } diff --git a/candle-wasm-examples/llama2-c/src/worker.rs b/candle-wasm-examples/llama2-c/src/worker.rs index 79f7c1fd..3a43c57a 100644 --- a/candle-wasm-examples/llama2-c/src/worker.rs +++ b/candle-wasm-examples/llama2-c/src/worker.rs @@ -4,6 +4,7 @@ use candle::{DType, Device, IndexOp, Result, Shape, Tensor, D}; use candle_nn::{ops::softmax, VarBuilder}; use rand::{distributions::Distribution, SeedableRng}; use serde::{Deserialize, Serialize}; +use tokenizers::Tokenizer; use wasm_bindgen::prelude::*; use yew_agent::{HandlerId, Public, WorkerLink}; @@ -48,23 +49,6 @@ fn read_tensor<R: std::io::Read, S: Into<Shape>>( Ok(tensor) } -struct Tokenizer { - tokens: Vec<String>, -} - -impl Tokenizer { - fn from_reader<R: std::io::Read>(r: &mut R, c: &Config) -> Result<Self> { - let mut tokens = Vec::with_capacity(c.vocab_size); - for _token_index in 0..c.vocab_size { - let token_len = read_i32(r)?; - let mut token = vec![0u8; token_len as usize]; - r.read_exact(&mut token)?; - tokens.push(String::from_utf8_lossy(&token).into_owned()) - } - Ok(Self { tokens }) - } -} - struct Model { cache: Cache, config: Config, @@ -129,8 +113,10 @@ impl Model { let next_token = logits_processor.sample(&logits)?; tokens.push(next_token); - let token = self.tokenizer.tokens[next_token as usize].clone(); - link.respond(id, Ok(WorkerOutput::Generated(token))); + if let Some(text) = self.tokenizer.id_to_token(next_token) { + let text = text.replace('▁', " ").replace("<0x0A>", "\n"); + link.respond(id, Ok(WorkerOutput::Generated(text))); + } } Ok(()) } @@ -282,8 +268,8 @@ impl Model { let vb = weights.var_builder(&config, &dev)?; let cache = Cache::new(true, &config, vb.pp("rot"))?; let llama = Llama::load(vb, &cache, &config)?; - let mut tokenizer = std::io::Cursor::new(md.tokenizer); - let tokenizer = Tokenizer::from_reader(&mut tokenizer, &config)?; + let tokenizer = + Tokenizer::from_bytes(&md.tokenizer).map_err(|m| candle::Error::Msg(m.to_string()))?; Ok(Self { cache, config, |