summaryrefslogtreecommitdiff
path: root/candle-wasm-examples/llama2-c
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-07-24 12:36:02 +0100
committerGitHub <noreply@github.com>2023-07-24 12:36:02 +0100
commit5a26cba7339e326eaca7a10ee99f6af948da2677 (patch)
treee7ce4f569f3d620bd73c0bbb00198031345723b2 /candle-wasm-examples/llama2-c
parent550a13a5472fd3aa3975c2453eff4bff6ac1d0bd (diff)
downloadcandle-5a26cba7339e326eaca7a10ee99f6af948da2677.tar.gz
candle-5a26cba7339e326eaca7a10ee99f6af948da2677.tar.bz2
candle-5a26cba7339e326eaca7a10ee99f6af948da2677.zip
Re-organize the wasm examples (#231)
* Move the whisper example. * More renaming. * Add llama2 as a new wasm example. * Live generation. * More of the llama wasm example. * Formatting.
Diffstat (limited to 'candle-wasm-examples/llama2-c')
-rw-r--r--candle-wasm-examples/llama2-c/Cargo.toml51
-rw-r--r--candle-wasm-examples/llama2-c/index.html17
-rw-r--r--candle-wasm-examples/llama2-c/src/app.rs188
-rw-r--r--candle-wasm-examples/llama2-c/src/bin/app.rs4
-rw-r--r--candle-wasm-examples/llama2-c/src/bin/worker.rs4
-rw-r--r--candle-wasm-examples/llama2-c/src/lib.rs30
-rw-r--r--candle-wasm-examples/llama2-c/src/model.rs321
-rw-r--r--candle-wasm-examples/llama2-c/src/worker.rs353
8 files changed, 968 insertions, 0 deletions
diff --git a/candle-wasm-examples/llama2-c/Cargo.toml b/candle-wasm-examples/llama2-c/Cargo.toml
new file mode 100644
index 00000000..22d9cfe8
--- /dev/null
+++ b/candle-wasm-examples/llama2-c/Cargo.toml
@@ -0,0 +1,51 @@
+[package]
+name = "candle-wasm-example-llama2"
+version = "0.1.0"
+edition = "2021"
+
+description = "Wasm example for the candle ML framework."
+repository = "https://github.com/LaurentMazare/candle"
+keywords = ["blas", "tensor", "machine-learning"]
+categories = ["science"]
+license = "MIT/Apache-2.0"
+readme = "README.md"
+
+[dependencies]
+candle = { path = "../../candle-core" }
+candle-nn = { path = "../../candle-nn" }
+num-traits = { workspace = true }
+
+# App crates.
+anyhow = { workspace = true }
+byteorder = { workspace = true }
+log = { workspace = true }
+rand = { workspace = true }
+serde = { workspace = true }
+serde_json = { workspace = true }
+
+# Wasm specific crates.
+getrandom = { version = "0.2", features = ["js"] }
+gloo = "0.8"
+js-sys = "0.3.64"
+wasm-bindgen = "0.2.87"
+wasm-bindgen-futures = "0.4.37"
+wasm-logger = "0.2"
+yew-agent = "0.2.0"
+yew = { version = "0.20.0", features = ["csr"] }
+
+[dependencies.web-sys]
+version = "0.3.64"
+features = [
+ 'Blob',
+ 'Document',
+ 'Element',
+ 'HtmlElement',
+ 'Node',
+ 'Window',
+ 'Request',
+ 'RequestCache',
+ 'RequestInit',
+ 'RequestMode',
+ 'Response',
+ 'Performance',
+]
diff --git a/candle-wasm-examples/llama2-c/index.html b/candle-wasm-examples/llama2-c/index.html
new file mode 100644
index 00000000..e98e1ecb
--- /dev/null
+++ b/candle-wasm-examples/llama2-c/index.html
@@ -0,0 +1,17 @@
+<!DOCTYPE html>
+<html lang="en">
+ <head>
+ <meta charset="utf-8" />
+ <title>Welcome to Candle!</title>
+
+ <link data-trunk rel="copy-file" href="tokenizer.bin" />
+ <link data-trunk rel="copy-file" href="model.bin" />
+ <link data-trunk rel="rust" href="Cargo.toml" data-bin="app" data-type="main" />
+ <link data-trunk rel="rust" href="Cargo.toml" data-bin="worker" data-type="worker" />
+
+ <link rel="stylesheet" href="https://fonts.googleapis.com/css?family=Roboto:300,300italic,700,700italic">
+ <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/normalize/8.0.1/normalize.css">
+ <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/milligram/1.4.1/milligram.css">
+ </head>
+ <body></body>
+</html>
diff --git a/candle-wasm-examples/llama2-c/src/app.rs b/candle-wasm-examples/llama2-c/src/app.rs
new file mode 100644
index 00000000..460ac053
--- /dev/null
+++ b/candle-wasm-examples/llama2-c/src/app.rs
@@ -0,0 +1,188 @@
+use crate::console_log;
+use crate::worker::{ModelData, Worker, WorkerInput, WorkerOutput};
+use wasm_bindgen::prelude::*;
+use wasm_bindgen_futures::JsFuture;
+use yew::{html, Component, Context, Html};
+use yew_agent::{Bridge, Bridged};
+
+async fn fetch_url(url: &str) -> Result<Vec<u8>, JsValue> {
+ use web_sys::{Request, RequestCache, RequestInit, RequestMode, Response};
+ let window = web_sys::window().ok_or("window")?;
+ let mut opts = RequestInit::new();
+ let opts = opts
+ .method("GET")
+ .mode(RequestMode::Cors)
+ .cache(RequestCache::NoCache);
+
+ let request = Request::new_with_str_and_init(url, opts)?;
+
+ let resp_value = JsFuture::from(window.fetch_with_request(&request)).await?;
+
+ // `resp_value` is a `Response` object.
+ assert!(resp_value.is_instance_of::<Response>());
+ let resp: Response = resp_value.dyn_into()?;
+ let data = JsFuture::from(resp.blob()?).await?;
+ let blob = web_sys::Blob::from(data);
+ let array_buffer = JsFuture::from(blob.array_buffer()).await?;
+ let data = js_sys::Uint8Array::new(&array_buffer).to_vec();
+ Ok(data)
+}
+
+pub enum Msg {
+ Run,
+ UpdateStatus(String),
+ SetModel(ModelData),
+ WorkerInMsg(WorkerInput),
+ WorkerOutMsg(Result<WorkerOutput, String>),
+}
+
+pub struct CurrentDecode {
+ start_time: Option<f64>,
+}
+
+pub struct App {
+ status: String,
+ generated: String,
+ current_decode: Option<CurrentDecode>,
+ worker: Box<dyn Bridge<Worker>>,
+}
+
+async fn model_data_load() -> Result<ModelData, JsValue> {
+ let tokenizer = fetch_url("tokenizer.bin").await?;
+ let model = fetch_url("model.bin").await?;
+ console_log!("{}", model.len());
+ Ok(ModelData { tokenizer, model })
+}
+
+fn performance_now() -> Option<f64> {
+ let window = web_sys::window()?;
+ let performance = window.performance()?;
+ Some(performance.now() / 1000.)
+}
+
+impl Component for App {
+ type Message = Msg;
+ type Properties = ();
+
+ fn create(ctx: &Context<Self>) -> Self {
+ let status = "loading weights".to_string();
+ let cb = {
+ let link = ctx.link().clone();
+ move |e| link.send_message(Self::Message::WorkerOutMsg(e))
+ };
+ let worker = Worker::bridge(std::rc::Rc::new(cb));
+ Self {
+ status,
+ generated: String::new(),
+ current_decode: None,
+ worker,
+ }
+ }
+
+ fn rendered(&mut self, ctx: &Context<Self>, first_render: bool) {
+ if first_render {
+ ctx.link().send_future(async {
+ match model_data_load().await {
+ Err(err) => {
+ let status = format!("{err:?}");
+ Msg::UpdateStatus(status)
+ }
+ Ok(model_data) => Msg::SetModel(model_data),
+ }
+ });
+ }
+ }
+
+ fn update(&mut self, ctx: &Context<Self>, msg: Self::Message) -> bool {
+ match msg {
+ Msg::SetModel(md) => {
+ self.status = "weights loaded succesfully!".to_string();
+ console_log!("loaded weights");
+ self.worker.send(WorkerInput::ModelData(md));
+ true
+ }
+ Msg::Run => {
+ if self.current_decode.is_some() {
+ self.status = "already generating some sample at the moment".to_string()
+ } else {
+ let start_time = performance_now();
+ self.current_decode = Some(CurrentDecode { start_time });
+ self.status = "generating...".to_string();
+ self.generated.clear();
+ ctx.link().send_message(Msg::WorkerInMsg(WorkerInput::Run))
+ }
+ true
+ }
+ Msg::WorkerOutMsg(output) => {
+ match output {
+ Ok(WorkerOutput::WeightsLoaded) => self.status = "weights loaded!".to_string(),
+ Ok(WorkerOutput::GenerationDone(Err(err))) => {
+ self.status = format!("error in worker process: {err}");
+ self.current_decode = None
+ }
+ Ok(WorkerOutput::GenerationDone(Ok(()))) => {
+ let dt = self.current_decode.as_ref().and_then(|current_decode| {
+ current_decode.start_time.and_then(|start_time| {
+ performance_now().map(|stop_time| stop_time - start_time)
+ })
+ });
+ self.status = match dt {
+ None => "generation succeeded!".to_string(),
+ Some(dt) => format!("generation succeeded in {:.2}s", dt),
+ };
+ self.current_decode = None
+ }
+ Ok(WorkerOutput::Generated(token)) => self.generated.push_str(&token),
+ Err(err) => {
+ self.status = format!("error in worker {err:?}");
+ }
+ }
+ true
+ }
+ Msg::WorkerInMsg(inp) => {
+ self.worker.send(inp);
+ true
+ }
+ Msg::UpdateStatus(status) => {
+ self.status = status;
+ true
+ }
+ }
+ }
+
+ fn view(&self, ctx: &Context<Self>) -> Html {
+ html! {
+ <div>
+ <div><p>{"Running "}
+ <a href="https://github.com/karpathy/llama2.c" target="_blank">{"llama2.c"}</a>
+ {" in the browser using rust/wasm with "}
+ <a href="https://github.com/LaurentMazare/candle" target="_blank">{"candle!"}</a>
+ </p>
+ <p>{"Once the weights have loaded, click on the run button to start generating content."}
+ </p>
+ </div>
+ <button class="button" onclick={ctx.link().callback(move |_| Msg::Run)}> { "run" }</button>
+ <br/ >
+ <h3>
+ {&self.status}
+ </h3>
+ {
+ if self.current_decode.is_some() {
+ html! { <progress id="progress-bar" aria-label="generating…"></progress> }
+ } else {
+ html! {}
+ }
+ }
+ <blockquote>
+ <p> { self.generated.chars().map(|c|
+ if c == '\r' || c == '\n' {
+ html! { <br/> }
+ } else {
+ html! { {c} }
+ }).collect::<Html>()
+ } </p>
+ </blockquote>
+ </div>
+ }
+ }
+}
diff --git a/candle-wasm-examples/llama2-c/src/bin/app.rs b/candle-wasm-examples/llama2-c/src/bin/app.rs
new file mode 100644
index 00000000..3428f6ff
--- /dev/null
+++ b/candle-wasm-examples/llama2-c/src/bin/app.rs
@@ -0,0 +1,4 @@
+fn main() {
+ wasm_logger::init(wasm_logger::Config::new(log::Level::Trace));
+ yew::Renderer::<candle_wasm_example_llama2::App>::new().render();
+}
diff --git a/candle-wasm-examples/llama2-c/src/bin/worker.rs b/candle-wasm-examples/llama2-c/src/bin/worker.rs
new file mode 100644
index 00000000..d8ca2172
--- /dev/null
+++ b/candle-wasm-examples/llama2-c/src/bin/worker.rs
@@ -0,0 +1,4 @@
+use yew_agent::PublicWorker;
+fn main() {
+ candle_wasm_example_llama2::Worker::register();
+}
diff --git a/candle-wasm-examples/llama2-c/src/lib.rs b/candle-wasm-examples/llama2-c/src/lib.rs
new file mode 100644
index 00000000..61154d04
--- /dev/null
+++ b/candle-wasm-examples/llama2-c/src/lib.rs
@@ -0,0 +1,30 @@
+#![allow(dead_code)]
+
+pub const WITH_TIMER: bool = true;
+
+struct Timer {
+ label: &'static str,
+}
+
+impl Timer {
+ fn new(label: &'static str) -> Self {
+ if WITH_TIMER {
+ web_sys::console::time_with_label(label);
+ }
+ Self { label }
+ }
+}
+
+impl Drop for Timer {
+ fn drop(&mut self) {
+ if WITH_TIMER {
+ web_sys::console::time_end_with_label(self.label)
+ }
+ }
+}
+
+mod app;
+mod model;
+mod worker;
+pub use app::App;
+pub use worker::Worker;
diff --git a/candle-wasm-examples/llama2-c/src/model.rs b/candle-wasm-examples/llama2-c/src/model.rs
new file mode 100644
index 00000000..13f939db
--- /dev/null
+++ b/candle-wasm-examples/llama2-c/src/model.rs
@@ -0,0 +1,321 @@
+use candle::{DType, Device, IndexOp, Result, Tensor, D};
+use candle_nn::{Embedding, Linear, VarBuilder};
+use std::collections::HashMap;
+use std::sync::{Arc, Mutex};
+
+#[derive(Debug, Clone)]
+pub struct Config {
+ pub dim: usize, // transformer dimension
+ pub hidden_dim: usize, // for ffn layers
+ pub n_layers: usize, // number of layers
+ pub n_heads: usize, // number of query heads
+ pub n_kv_heads: usize, // number of key/value heads (can be < query heads because of multiquery)
+ pub vocab_size: usize, // vocabulary size, usually 256 (byte-level)
+ pub seq_len: usize, // max sequence length
+ pub norm_eps: f64,
+}
+
+#[derive(Clone)]
+pub struct Cache {
+ masks: Arc<Mutex<HashMap<usize, Tensor>>>,
+ pub use_kv_cache: bool,
+ #[allow(clippy::type_complexity)]
+ kvs: Arc<Mutex<Vec<Option<(Tensor, Tensor)>>>>,
+ cos: Tensor,
+ sin: Tensor,
+ device: Device,
+}
+
+impl Cache {
+ pub fn new(use_kv_cache: bool, cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let freq_cis_real = vb.get((cfg.seq_len, cfg.head_size() / 2), "freq_cis_real")?;
+ let freq_cis_imag = vb.get((cfg.seq_len, cfg.head_size() / 2), "freq_cis_imag")?;
+ let cos = freq_cis_real.reshape((cfg.seq_len, cfg.head_size() / 2, 1))?;
+ let sin = freq_cis_imag.reshape((cfg.seq_len, cfg.head_size() / 2, 1))?;
+ Ok(Self {
+ masks: Arc::new(Mutex::new(HashMap::new())),
+ use_kv_cache,
+ kvs: Arc::new(Mutex::new(vec![None; cfg.n_layers])),
+ cos,
+ sin,
+ device: vb.device().clone(),
+ })
+ }
+
+ fn mask(&self, t: usize) -> Result<Tensor> {
+ let mut masks = self.masks.lock().unwrap();
+ if let Some(mask) = masks.get(&t) {
+ Ok(mask.clone())
+ } else {
+ // TODO: If we support bool or u8 tensors, this would be better.
+ let mask: Vec<_> = (0..t)
+ .flat_map(|i| (0..t).map(move |j| u32::from(j > i)))
+ .collect();
+ let mask = Tensor::from_slice(&mask, (t, t), &self.device)?;
+ masks.insert(t, mask.clone());
+ Ok(mask)
+ }
+ }
+}
+
+fn silu(xs: &Tensor) -> Result<Tensor> {
+ xs / (xs.neg()?.exp()? + 1.0)?
+}
+
+fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
+ let weight = vb.get((size2, size1), "weight")?;
+ Ok(Linear::new(weight, None))
+}
+
+fn embedding(cfg: &Config, vb: VarBuilder) -> Result<Embedding> {
+ let embeddings = vb.get((cfg.vocab_size, cfg.dim), "weight")?;
+ Ok(Embedding::new(embeddings, cfg.dim))
+}
+
+struct RmsNorm {
+ scale: Tensor,
+ eps: f64,
+}
+
+impl RmsNorm {
+ fn load(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
+ let scale = vb.get(size, "weight")?;
+ Ok(Self { scale, eps })
+ }
+
+ fn forward(&self, x: &Tensor) -> Result<Tensor> {
+ let (b_sz, seq_len, hidden_size) = x.dims3()?;
+ let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
+ let norm_x = norm_x.broadcast_as((b_sz, seq_len, hidden_size))?;
+ let x_normed = (x / (norm_x + self.eps)?.sqrt()?)?;
+ let size = self.scale.dims1()?;
+ let scale = self
+ .scale
+ .to_dtype(DType::F32)?
+ .broadcast_as((b_sz, seq_len, size))?;
+ let x = (scale * x_normed)?;
+ Ok(x)
+ }
+}
+
+struct CausalSelfAttention {
+ q_proj: Linear,
+ k_proj: Linear,
+ v_proj: Linear,
+ o_proj: Linear,
+ n_head: usize,
+ n_key_value_head: usize,
+ head_dim: usize,
+ cache: Cache,
+ max_seq_len: usize,
+}
+
+impl CausalSelfAttention {
+ fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
+ let (b_sz, seq_len, h, n_embd) = x.dims4()?;
+ let cos = self.cache.cos.narrow(0, index_pos, seq_len)?;
+ let sin = self.cache.sin.narrow(0, index_pos, seq_len)?;
+ let cos = cos.broadcast_as((b_sz, seq_len, 1, n_embd / 2, 1))?;
+ let sin = sin.broadcast_as((b_sz, seq_len, 1, n_embd / 2, 1))?;
+ let x = x.reshape((b_sz, seq_len, h, n_embd / 2, 2))?;
+ let x0 = x.narrow(D::Minus1, 0, 1)?;
+ let x1 = x.narrow(D::Minus1, 1, 1)?;
+ let dst0 = (x0.broadcast_mul(&cos)? - x1.broadcast_mul(&sin)?)?;
+ let dst1 = (x0.broadcast_mul(&sin)? + x1.broadcast_mul(&cos)?)?;
+ let rope = Tensor::cat(&[&dst0, &dst1], D::Minus1)?.reshape((b_sz, seq_len, h, n_embd))?;
+ Ok(rope)
+ }
+
+ fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
+ let (b_sz, seq_len, n_embd) = x.dims3()?;
+ let q = self.q_proj.forward(x)?;
+ let k = self.k_proj.forward(x)?;
+ let v = self.v_proj.forward(x)?;
+
+ let q = q.reshape((b_sz, seq_len, self.n_head, self.head_dim))?;
+ let k = k.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?;
+ let mut v = v.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?;
+
+ let q = self.apply_rotary_emb(&q, index_pos)?;
+ let mut k = self.apply_rotary_emb(&k, index_pos)?;
+
+ if self.cache.use_kv_cache {
+ let mut cache = self.cache.kvs.lock().unwrap();
+ if let Some((cache_k, cache_v)) = &cache[block_idx] {
+ k = Tensor::cat(&[cache_k, &k], 1)?.contiguous()?;
+ v = Tensor::cat(&[cache_v, &v], 1)?.contiguous()?;
+ }
+ cache[block_idx] = Some((k.clone(), v.clone()))
+ }
+
+ let k = self.repeat_kv(k)?;
+ let v = self.repeat_kv(v)?;
+
+ let q = q.transpose(1, 2)?.contiguous()?;
+ let k = k.transpose(1, 2)?.contiguous()?;
+ let v = v.transpose(1, 2)?.contiguous()?;
+
+ let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
+ let mask = self.cache.mask(seq_len)?.broadcast_as(att.shape())?;
+ let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
+ let att = att.softmax(D::Minus1)?;
+ // Convert to contiguous as matmul doesn't support strided vs for now.
+ let y = att.matmul(&v.contiguous()?)?;
+ let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
+ let y = self.o_proj.forward(&y)?;
+ Ok(y)
+ }
+
+ fn repeat_kv(&self, x: Tensor) -> Result<Tensor> {
+ let n_rep = self.n_head / self.n_key_value_head;
+ if n_rep == 1 {
+ Ok(x)
+ } else {
+ let (b_sz, seq_len, n_kv_head, head_dim) = x.dims4()?;
+ let x = x
+ .unsqueeze(3)?
+ .expand((b_sz, seq_len, n_kv_head, n_rep, head_dim))?
+ .reshape((b_sz, seq_len, n_kv_head * n_rep, head_dim))?;
+ Ok(x)
+ }
+ }
+
+ fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
+ let size_in = cfg.dim;
+ let size_q = (cfg.dim / cfg.n_heads) * cfg.n_heads;
+ let size_kv = (cfg.dim / cfg.n_heads) * cfg.n_kv_heads;
+ let q_proj = linear(size_in, size_q, vb.pp("q_proj"))?;
+ let k_proj = linear(size_in, size_kv, vb.pp("k_proj"))?;
+ let v_proj = linear(size_in, size_kv, vb.pp("v_proj"))?;
+ let o_proj = linear(size_q, size_in, vb.pp("o_proj"))?;
+ Ok(Self {
+ q_proj,
+ k_proj,
+ v_proj,
+ o_proj,
+ n_head: cfg.n_heads,
+ n_key_value_head: cfg.n_kv_heads,
+ head_dim: cfg.dim / cfg.n_heads,
+ cache: cache.clone(),
+ max_seq_len: cfg.seq_len,
+ })
+ }
+}
+
+fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
+ let shape = mask.shape();
+ let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
+ let m = mask.where_cond(&on_true, on_false)?;
+ Ok(m)
+}
+
+struct Mlp {
+ c_fc1: Linear,
+ c_fc2: Linear,
+ c_proj: Linear,
+}
+
+impl Mlp {
+ fn new(c_fc1: Linear, c_fc2: Linear, c_proj: Linear) -> Self {
+ Self {
+ c_fc1,
+ c_fc2,
+ c_proj,
+ }
+ }
+
+ fn forward(&self, x: &Tensor) -> Result<Tensor> {
+ let x = (silu(&self.c_fc1.forward(x)?)? * self.c_fc2.forward(x)?)?;
+ self.c_proj.forward(&x)
+ }
+
+ fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
+ let h_size = cfg.dim;
+ let i_size = cfg.hidden_dim;
+ let c_fc1 = linear(h_size, i_size, vb.pp("gate_proj"))?;
+ let c_fc2 = linear(h_size, i_size, vb.pp("up_proj"))?;
+ let c_proj = linear(i_size, h_size, vb.pp("down_proj"))?;
+ Ok(Self::new(c_fc1, c_fc2, c_proj))
+ }
+}
+
+struct Block {
+ rms_1: RmsNorm,
+ attn: CausalSelfAttention,
+ rms_2: RmsNorm,
+ mlp: Mlp,
+}
+
+impl Block {
+ fn new(rms_1: RmsNorm, attn: CausalSelfAttention, rms_2: RmsNorm, mlp: Mlp) -> Self {
+ Self {
+ rms_1,
+ attn,
+ rms_2,
+ mlp,
+ }
+ }
+
+ fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
+ let residual = x;
+ let x = self.rms_1.forward(x)?;
+ let x = (self.attn.forward(&x, index_pos, block_idx)? + residual)?;
+ let residual = &x;
+ let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? + residual)?;
+ Ok(x)
+ }
+
+ fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
+ let attn = CausalSelfAttention::load(vb.pp("self_attn"), cache, cfg)?;
+ let mlp = Mlp::load(vb.pp("mlp"), cfg)?;
+ let input_layernorm = RmsNorm::load(cfg.dim, cfg.norm_eps, vb.pp("input_layernorm"))?;
+ let post_attention_layernorm =
+ RmsNorm::load(cfg.dim, cfg.norm_eps, vb.pp("post_attention_layernorm"))?;
+ Ok(Self::new(
+ input_layernorm,
+ attn,
+ post_attention_layernorm,
+ mlp,
+ ))
+ }
+}
+
+pub struct Llama {
+ wte: Embedding,
+ blocks: Vec<Block>,
+ ln_f: RmsNorm,
+ lm_head: Linear,
+}
+
+impl Llama {
+ fn new(wte: Embedding, blocks: Vec<Block>, ln_f: RmsNorm, lm_head: Linear) -> Self {
+ Self {
+ wte,
+ blocks,
+ ln_f,
+ lm_head,
+ }
+ }
+
+ pub fn forward(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
+ let (_b_sz, seq_len) = x.dims2()?;
+ let mut x = self.wte.forward(x)?;
+ for (block_idx, block) in self.blocks.iter().enumerate() {
+ x = block.forward(&x, index_pos, block_idx)?;
+ }
+ let x = self.ln_f.forward(&x)?;
+ let x = x.i((.., seq_len - 1, ..))?;
+ let logits = self.lm_head.forward(&x)?;
+ logits.to_dtype(DType::F32)
+ }
+
+ pub fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
+ let wte = embedding(cfg, vb.pp("model.embed_tokens"))?;
+ let lm_head = linear(cfg.dim, cfg.vocab_size, vb.pp("lm_head"))?;
+ let norm = RmsNorm::load(cfg.dim, cfg.norm_eps, vb.pp("model.norm"))?;
+ let blocks: Vec<_> = (0..cfg.n_layers)
+ .map(|i| Block::load(vb.pp(&format!("model.layers.{i}")), cache, cfg).unwrap())
+ .collect();
+ Ok(Self::new(wte, blocks, norm, lm_head))
+ }
+}
diff --git a/candle-wasm-examples/llama2-c/src/worker.rs b/candle-wasm-examples/llama2-c/src/worker.rs
new file mode 100644
index 00000000..9b0351d6
--- /dev/null
+++ b/candle-wasm-examples/llama2-c/src/worker.rs
@@ -0,0 +1,353 @@
+use crate::model::{Cache, Config, Llama};
+use byteorder::{LittleEndian, ReadBytesExt};
+use candle::{DType, Device, IndexOp, Result, Shape, Tensor, D};
+use candle_nn::VarBuilder;
+use rand::{distributions::Distribution, SeedableRng};
+use serde::{Deserialize, Serialize};
+use wasm_bindgen::prelude::*;
+use yew_agent::{HandlerId, Public, WorkerLink};
+
+#[wasm_bindgen]
+extern "C" {
+ // Use `js_namespace` here to bind `console.log(..)` instead of just
+ // `log(..)`
+ #[wasm_bindgen(js_namespace = console)]
+ pub fn log(s: &str);
+}
+
+#[macro_export]
+macro_rules! console_log {
+ // Note that this is using the `log` function imported above during
+ // `bare_bones`
+ ($($t:tt)*) => ($crate::worker::log(&format_args!($($t)*).to_string()))
+}
+
+// Communication to the worker happens through bincode, the model weights and configs are fetched
+// on the main thread and transfered via the following structure.
+#[derive(Serialize, Deserialize)]
+pub struct ModelData {
+ pub tokenizer: Vec<u8>,
+ pub model: Vec<u8>,
+}
+
+fn read_i32<R: std::io::Read>(r: &mut R) -> Result<i32> {
+ let mut buf = [0u8; 4];
+ r.read_exact(&mut buf)?;
+ Ok(i32::from_le_bytes(buf))
+}
+
+fn read_tensor<R: std::io::Read, S: Into<Shape>>(
+ r: &mut R,
+ shape: S,
+ dev: &Device,
+) -> Result<Tensor> {
+ let shape = shape.into();
+ let mut data_t = vec![0f32; shape.elem_count()];
+ r.read_f32_into::<LittleEndian>(&mut data_t)?;
+ let tensor = Tensor::from_vec(data_t, shape, dev)?;
+ Ok(tensor)
+}
+
+struct Tokenizer {
+ tokens: Vec<String>,
+}
+
+impl Tokenizer {
+ fn from_reader<R: std::io::Read>(r: &mut R, c: &Config) -> Result<Self> {
+ let mut tokens = Vec::with_capacity(c.vocab_size);
+ for _token_index in 0..c.vocab_size {
+ let token_len = read_i32(r)?;
+ let mut token = vec![0u8; token_len as usize];
+ r.read_exact(&mut token)?;
+ tokens.push(String::from_utf8_lossy(&token).into_owned())
+ }
+ Ok(Self { tokens })
+ }
+}
+
+struct Model {
+ cache: Cache,
+ config: Config,
+ llama: Llama,
+ tokenizer: Tokenizer,
+}
+
+pub struct LogitsProcessor {
+ rng: rand::rngs::StdRng,
+ temperature: Option<f64>,
+}
+
+impl LogitsProcessor {
+ pub fn new(seed: u64, temperature: Option<f64>) -> Self {
+ Self {
+ rng: rand::rngs::StdRng::seed_from_u64(seed),
+ temperature,
+ }
+ }
+
+ pub fn sample(&mut self, logits: &Tensor) -> Result<u32> {
+ let logits = logits.to_dtype(DType::F32)?;
+ let next_token = if let Some(temperature) = self.temperature {
+ let prs = (&logits / temperature)?.softmax(D::Minus1)?;
+ let prs: Vec<f32> = prs.to_vec1()?;
+ let distr =
+ rand::distributions::WeightedIndex::new(prs).map_err(candle::Error::wrap)?;
+ distr.sample(&mut self.rng) as u32
+ } else {
+ let logits_v: Vec<f32> = logits.to_vec1()?;
+ logits_v
+ .iter()
+ .enumerate()
+ .max_by(|(_, u), (_, v)| u.total_cmp(v))
+ .map(|(i, _)| i as u32)
+ .unwrap()
+ };
+ Ok(next_token)
+ }
+}
+
+impl Model {
+ fn run(&self, link: &WorkerLink<Worker>, id: HandlerId) -> Result<()> {
+ let dev = Device::Cpu;
+ let mut logits_processor = LogitsProcessor::new(299792458, None);
+ let mut index_pos = 0;
+ let mut tokens = vec![1u32];
+
+ for index in 0..self.config.seq_len - 10 {
+ let context_size = if self.cache.use_kv_cache && index > 0 {
+ 1
+ } else {
+ tokens.len()
+ };
+ let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
+ let input = Tensor::new(ctxt, &dev)?.unsqueeze(0)?;
+ let logits = self.llama.forward(&input, index_pos)?;
+ let logits = logits.squeeze(0)?;
+ index_pos += ctxt.len();
+
+ let next_token = logits_processor.sample(&logits)?;
+ tokens.push(next_token);
+ let token = self.tokenizer.tokens[next_token as usize].clone();
+ link.respond(id, Ok(WorkerOutput::Generated(token)));
+ }
+ Ok(())
+ }
+}
+
+impl Config {
+ fn from_reader<R: std::io::Read>(r: &mut R) -> Result<Self> {
+ let dim = read_i32(r)? as usize;
+ let hidden_dim = read_i32(r)? as usize;
+ let n_layers = read_i32(r)? as usize;
+ let n_heads = read_i32(r)? as usize;
+ let n_kv_heads = read_i32(r)? as usize;
+ let vocab_size = read_i32(r)? as usize;
+ let seq_len = read_i32(r)? as usize;
+ Ok(Self {
+ dim,
+ hidden_dim,
+ n_layers,
+ n_heads,
+ n_kv_heads,
+ vocab_size,
+ seq_len,
+ norm_eps: 1e-5,
+ })
+ }
+
+ pub fn head_size(&self) -> usize {
+ self.dim / self.n_heads
+ }
+}
+
+struct TransformerWeights {
+ // token embedding table
+ token_embedding_table: Tensor, // (vocab_size, dim)
+ // weights for rmsnorms
+ rms_att_weight: Tensor, // (layer, dim) rmsnorm weights
+ rms_ffn_weight: Tensor, // (layer, dim)
+ // weights for matmuls
+ wq: Tensor, // (layer, dim, dim)
+ wk: Tensor, // (layer, dim, dim)
+ wv: Tensor, // (layer, dim, dim)
+ wo: Tensor, // (layer, dim, dim)
+ // weights for ffn
+ w1: Tensor, // (layer, hidden_dim, dim)
+ w2: Tensor, // (layer, dim, hidden_dim)
+ w3: Tensor, // (layer, hidden_dim, dim)
+ // final rmsnorm
+ rms_final_weight: Tensor, // (dim,)
+ // freq_cis for RoPE relatively positional embeddings
+ freq_cis_real: Tensor, // (seq_len, head_size/2)
+ freq_cis_imag: Tensor, // (seq_len, head_size/2)
+}
+
+impl TransformerWeights {
+ fn from_reader<R: std::io::Read>(r: &mut R, c: &Config, dev: &Device) -> Result<Self> {
+ let token_embedding_table = read_tensor(r, (c.vocab_size, c.dim), dev)?;
+ let rms_att_weight = read_tensor(r, (c.n_layers, c.dim), dev)?;
+ let wq = read_tensor(r, (c.n_layers, c.dim, c.dim), dev)?;
+ let wk = read_tensor(r, (c.n_layers, c.dim, c.dim), dev)?;
+ let wv = read_tensor(r, (c.n_layers, c.dim, c.dim), dev)?;
+ let wo = read_tensor(r, (c.n_layers, c.dim, c.dim), dev)?;
+ let rms_ffn_weight = read_tensor(r, (c.n_layers, c.dim), dev)?;
+ let w1 = read_tensor(r, (c.n_layers, c.hidden_dim, c.dim), dev)?;
+ let w2 = read_tensor(r, (c.n_layers, c.dim, c.hidden_dim), dev)?;
+ let w3 = read_tensor(r, (c.n_layers, c.hidden_dim, c.dim), dev)?;
+ let rms_final_weight = read_tensor(r, c.dim, dev)?;
+ let head_size = c.head_size();
+ let freq_cis_real = read_tensor(r, (c.seq_len, head_size / 2), dev)?;
+ let freq_cis_imag = read_tensor(r, (c.seq_len, head_size / 2), dev)?;
+ Ok(Self {
+ token_embedding_table,
+ rms_att_weight,
+ wq,
+ wk,
+ wv,
+ wo,
+ rms_ffn_weight,
+ w1,
+ w2,
+ w3,
+ rms_final_weight,
+ freq_cis_real,
+ freq_cis_imag,
+ })
+ }
+
+ fn var_builder(&self, cfg: &Config, device: &Device) -> Result<VarBuilder> {
+ let mut ws = std::collections::HashMap::new();
+ let mut insert = |name: &str, t: Tensor| {
+ ws.insert(name.to_string(), t);
+ };
+ insert("rot.freq_cis_real", self.freq_cis_real.clone());
+ insert("rot.freq_cis_imag", self.freq_cis_imag.clone());
+ insert(
+ "model.embed_tokens.weight",
+ self.token_embedding_table.clone(),
+ );
+ insert("lm_head.weight", self.token_embedding_table.clone());
+ insert("model.norm.weight", self.rms_final_weight.clone());
+ for layer in 0..cfg.n_layers {
+ ws.insert(
+ format!("model.layers.{layer}.self_attn.q_proj.weight"),
+ self.wq.i(layer)?,
+ );
+ ws.insert(
+ format!("model.layers.{layer}.self_attn.k_proj.weight"),
+ self.wk.i(layer)?,
+ );
+ ws.insert(
+ format!("model.layers.{layer}.self_attn.v_proj.weight"),
+ self.wv.i(layer)?,
+ );
+ ws.insert(
+ format!("model.layers.{layer}.self_attn.o_proj.weight"),
+ self.wo.i(layer)?,
+ );
+ ws.insert(
+ format!("model.layers.{layer}.mlp.gate_proj.weight"),
+ self.w1.i(layer)?,
+ );
+ ws.insert(
+ format!("model.layers.{layer}.mlp.down_proj.weight"),
+ self.w2.i(layer)?,
+ );
+ ws.insert(
+ format!("model.layers.{layer}.mlp.up_proj.weight"),
+ self.w3.i(layer)?,
+ );
+ ws.insert(
+ format!("model.layers.{layer}.input_layernorm.weight"),
+ self.rms_att_weight.i(layer)?,
+ );
+ ws.insert(
+ format!("model.layers.{layer}.post_attention_layernorm.weight"),
+ self.rms_ffn_weight.i(layer)?,
+ );
+ }
+ let vb = VarBuilder::from_tensors(ws, DType::F32, device);
+ Ok(vb)
+ }
+}
+
+impl Model {
+ fn load(md: ModelData) -> Result<Self> {
+ let dev = Device::Cpu;
+ let mut model = std::io::Cursor::new(md.model);
+ let config = Config::from_reader(&mut model)?;
+ let weights = TransformerWeights::from_reader(&mut model, &config, &dev)?;
+ let vb = weights.var_builder(&config, &dev)?;
+ let cache = Cache::new(true, &config, vb.pp("rot"))?;
+ let llama = Llama::load(vb, &cache, &config)?;
+ let mut tokenizer = std::io::Cursor::new(md.tokenizer);
+ let tokenizer = Tokenizer::from_reader(&mut tokenizer, &config)?;
+ Ok(Self {
+ cache,
+ config,
+ llama,
+ tokenizer,
+ })
+ }
+}
+
+pub struct Worker {
+ link: WorkerLink<Self>,
+ model: Option<Model>,
+}
+
+#[derive(Serialize, Deserialize)]
+pub enum WorkerInput {
+ ModelData(ModelData),
+ Run,
+}
+
+#[derive(Serialize, Deserialize)]
+pub enum WorkerOutput {
+ Generated(String),
+ GenerationDone(std::result::Result<(), String>),
+ WeightsLoaded,
+}
+
+impl yew_agent::Worker for Worker {
+ type Input = WorkerInput;
+ type Message = ();
+ type Output = std::result::Result<WorkerOutput, String>;
+ type Reach = Public<Self>;
+
+ fn create(link: WorkerLink<Self>) -> Self {
+ Self { link, model: None }
+ }
+
+ fn update(&mut self, _msg: Self::Message) {
+ // no messaging
+ }
+
+ fn handle_input(&mut self, msg: Self::Input, id: HandlerId) {
+ let output = match msg {
+ WorkerInput::ModelData(md) => match Model::load(md) {
+ Ok(model) => {
+ self.model = Some(model);
+ Ok(WorkerOutput::WeightsLoaded)
+ }
+ Err(err) => Err(format!("model creation error {err:?}")),
+ },
+ WorkerInput::Run => match &self.model {
+ None => Err("model has not been set yet".to_string()),
+ Some(model) => {
+ let result = model.run(&self.link, id).map_err(|e| e.to_string());
+ Ok(WorkerOutput::GenerationDone(result))
+ }
+ },
+ };
+ self.link.respond(id, output);
+ }
+
+ fn name_of_resource() -> &'static str {
+ "worker.js"
+ }
+
+ fn resource_path_is_relative() -> bool {
+ true
+ }
+}