Diffstat (limited to 'candle-wasm-examples')
-rw-r--r--  candle-wasm-examples/llama2-c/Cargo.toml         51
-rw-r--r--  candle-wasm-examples/llama2-c/index.html         17
-rw-r--r--  candle-wasm-examples/llama2-c/src/app.rs        188
-rw-r--r--  candle-wasm-examples/llama2-c/src/bin/app.rs      4
-rw-r--r--  candle-wasm-examples/llama2-c/src/bin/worker.rs   4
-rw-r--r--  candle-wasm-examples/llama2-c/src/lib.rs         30
-rw-r--r--  candle-wasm-examples/llama2-c/src/model.rs      321
-rw-r--r--  candle-wasm-examples/llama2-c/src/worker.rs     353
-rw-r--r--  candle-wasm-examples/whisper/Cargo.toml          52
-rw-r--r--  candle-wasm-examples/whisper/index.html          24
-rw-r--r--  candle-wasm-examples/whisper/main.js              6
-rw-r--r--  candle-wasm-examples/whisper/src/app.rs         238
-rw-r--r--  candle-wasm-examples/whisper/src/audio.rs       217
-rw-r--r--  candle-wasm-examples/whisper/src/bin/app.rs       4
-rw-r--r--  candle-wasm-examples/whisper/src/bin/worker.rs    4
-rw-r--r--  candle-wasm-examples/whisper/src/lib.rs          31
-rw-r--r--  candle-wasm-examples/whisper/src/model.rs       421
-rw-r--r--  candle-wasm-examples/whisper/src/worker.rs      345
18 files changed, 2310 insertions, 0 deletions
diff --git a/candle-wasm-examples/llama2-c/Cargo.toml b/candle-wasm-examples/llama2-c/Cargo.toml
new file mode 100644
index 00000000..22d9cfe8
--- /dev/null
+++ b/candle-wasm-examples/llama2-c/Cargo.toml
@@ -0,0 +1,51 @@
+[package]
+name = "candle-wasm-example-llama2"
+version = "0.1.0"
+edition = "2021"
+
+description = "Wasm example for the candle ML framework."
+repository = "https://github.com/LaurentMazare/candle"
+keywords = ["blas", "tensor", "machine-learning"]
+categories = ["science"]
+license = "MIT/Apache-2.0"
+readme = "README.md"
+
+[dependencies]
+candle = { path = "../../candle-core" }
+candle-nn = { path = "../../candle-nn" }
+num-traits = { workspace = true }
+
+# App crates.
+anyhow = { workspace = true }
+byteorder = { workspace = true }
+log = { workspace = true }
+rand = { workspace = true }
+serde = { workspace = true }
+serde_json = { workspace = true }
+
+# Wasm specific crates.
+getrandom = { version = "0.2", features = ["js"] }
+gloo = "0.8"
+js-sys = "0.3.64"
+wasm-bindgen = "0.2.87"
+wasm-bindgen-futures = "0.4.37"
+wasm-logger = "0.2"
+yew-agent = "0.2.0"
+yew = { version = "0.20.0", features = ["csr"] }
+
+[dependencies.web-sys]
+version = "0.3.64"
+features = [
+ 'Blob',
+ 'Document',
+ 'Element',
+ 'HtmlElement',
+ 'Node',
+ 'Window',
+ 'Request',
+ 'RequestCache',
+ 'RequestInit',
+ 'RequestMode',
+ 'Response',
+ 'Performance',
+]
diff --git a/candle-wasm-examples/llama2-c/index.html b/candle-wasm-examples/llama2-c/index.html
new file mode 100644
index 00000000..e98e1ecb
--- /dev/null
+++ b/candle-wasm-examples/llama2-c/index.html
@@ -0,0 +1,17 @@
+<!DOCTYPE html>
+<html lang="en">
+ <head>
+ <meta charset="utf-8" />
+ <title>Welcome to Candle!</title>
+
+ <link data-trunk rel="copy-file" href="tokenizer.bin" />
+ <link data-trunk rel="copy-file" href="model.bin" />
+ <link data-trunk rel="rust" href="Cargo.toml" data-bin="app" data-type="main" />
+ <link data-trunk rel="rust" href="Cargo.toml" data-bin="worker" data-type="worker" />
+
+ <link rel="stylesheet" href="https://fonts.googleapis.com/css?family=Roboto:300,300italic,700,700italic">
+ <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/normalize/8.0.1/normalize.css">
+ <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/milligram/1.4.1/milligram.css">
+ </head>
+ <body></body>
+</html>
diff --git a/candle-wasm-examples/llama2-c/src/app.rs b/candle-wasm-examples/llama2-c/src/app.rs
new file mode 100644
index 00000000..460ac053
--- /dev/null
+++ b/candle-wasm-examples/llama2-c/src/app.rs
@@ -0,0 +1,188 @@
+use crate::console_log;
+use crate::worker::{ModelData, Worker, WorkerInput, WorkerOutput};
+use wasm_bindgen::prelude::*;
+use wasm_bindgen_futures::JsFuture;
+use yew::{html, Component, Context, Html};
+use yew_agent::{Bridge, Bridged};
+
+async fn fetch_url(url: &str) -> Result<Vec<u8>, JsValue> {
+ use web_sys::{Request, RequestCache, RequestInit, RequestMode, Response};
+ let window = web_sys::window().ok_or("window")?;
+ let mut opts = RequestInit::new();
+ let opts = opts
+ .method("GET")
+ .mode(RequestMode::Cors)
+ .cache(RequestCache::NoCache);
+
+ let request = Request::new_with_str_and_init(url, opts)?;
+
+ let resp_value = JsFuture::from(window.fetch_with_request(&request)).await?;
+
+ // `resp_value` is a `Response` object.
+ assert!(resp_value.is_instance_of::<Response>());
+ let resp: Response = resp_value.dyn_into()?;
+ let data = JsFuture::from(resp.blob()?).await?;
+ let blob = web_sys::Blob::from(data);
+ let array_buffer = JsFuture::from(blob.array_buffer()).await?;
+ let data = js_sys::Uint8Array::new(&array_buffer).to_vec();
+ Ok(data)
+}
+
+pub enum Msg {
+ Run,
+ UpdateStatus(String),
+ SetModel(ModelData),
+ WorkerInMsg(WorkerInput),
+ WorkerOutMsg(Result<WorkerOutput, String>),
+}
+
+pub struct CurrentDecode {
+ start_time: Option<f64>,
+}
+
+pub struct App {
+ status: String,
+ generated: String,
+ current_decode: Option<CurrentDecode>,
+ worker: Box<dyn Bridge<Worker>>,
+}
+
+async fn model_data_load() -> Result<ModelData, JsValue> {
+ let tokenizer = fetch_url("tokenizer.bin").await?;
+ let model = fetch_url("model.bin").await?;
+ console_log!("{}", model.len());
+ Ok(ModelData { tokenizer, model })
+}
+
+fn performance_now() -> Option<f64> {
+ let window = web_sys::window()?;
+ let performance = window.performance()?;
+ Some(performance.now() / 1000.)
+}
+
+impl Component for App {
+ type Message = Msg;
+ type Properties = ();
+
+ fn create(ctx: &Context<Self>) -> Self {
+ let status = "loading weights".to_string();
+ let cb = {
+ let link = ctx.link().clone();
+ move |e| link.send_message(Self::Message::WorkerOutMsg(e))
+ };
+ let worker = Worker::bridge(std::rc::Rc::new(cb));
+ Self {
+ status,
+ generated: String::new(),
+ current_decode: None,
+ worker,
+ }
+ }
+
+ fn rendered(&mut self, ctx: &Context<Self>, first_render: bool) {
+ if first_render {
+ ctx.link().send_future(async {
+ match model_data_load().await {
+ Err(err) => {
+ let status = format!("{err:?}");
+ Msg::UpdateStatus(status)
+ }
+ Ok(model_data) => Msg::SetModel(model_data),
+ }
+ });
+ }
+ }
+
+ fn update(&mut self, ctx: &Context<Self>, msg: Self::Message) -> bool {
+ match msg {
+ Msg::SetModel(md) => {
+ self.status = "weights loaded succesfully!".to_string();
+ console_log!("loaded weights");
+ self.worker.send(WorkerInput::ModelData(md));
+ true
+ }
+ Msg::Run => {
+ if self.current_decode.is_some() {
+ self.status = "already generating some sample at the moment".to_string()
+ } else {
+ let start_time = performance_now();
+ self.current_decode = Some(CurrentDecode { start_time });
+ self.status = "generating...".to_string();
+ self.generated.clear();
+ ctx.link().send_message(Msg::WorkerInMsg(WorkerInput::Run))
+ }
+ true
+ }
+ Msg::WorkerOutMsg(output) => {
+ match output {
+ Ok(WorkerOutput::WeightsLoaded) => self.status = "weights loaded!".to_string(),
+ Ok(WorkerOutput::GenerationDone(Err(err))) => {
+ self.status = format!("error in worker process: {err}");
+ self.current_decode = None
+ }
+ Ok(WorkerOutput::GenerationDone(Ok(()))) => {
+ let dt = self.current_decode.as_ref().and_then(|current_decode| {
+ current_decode.start_time.and_then(|start_time| {
+ performance_now().map(|stop_time| stop_time - start_time)
+ })
+ });
+ self.status = match dt {
+ None => "generation succeeded!".to_string(),
+ Some(dt) => format!("generation succeeded in {:.2}s", dt),
+ };
+ self.current_decode = None
+ }
+ Ok(WorkerOutput::Generated(token)) => self.generated.push_str(&token),
+ Err(err) => {
+ self.status = format!("error in worker {err:?}");
+ }
+ }
+ true
+ }
+ Msg::WorkerInMsg(inp) => {
+ self.worker.send(inp);
+ true
+ }
+ Msg::UpdateStatus(status) => {
+ self.status = status;
+ true
+ }
+ }
+ }
+
+ fn view(&self, ctx: &Context<Self>) -> Html {
+ html! {
+ <div>
+ <div><p>{"Running "}
+ <a href="https://github.com/karpathy/llama2.c" target="_blank">{"llama2.c"}</a>
+ {" in the browser using rust/wasm with "}
+ <a href="https://github.com/LaurentMazare/candle" target="_blank">{"candle!"}</a>
+ </p>
+ <p>{"Once the weights have loaded, click on the run button to start generating content."}
+ </p>
+ </div>
+ <button class="button" onclick={ctx.link().callback(move |_| Msg::Run)}> { "run" }</button>
+            <br/>
+ <h3>
+ {&self.status}
+ </h3>
+ {
+ if self.current_decode.is_some() {
+ html! { <progress id="progress-bar" aria-label="generating…"></progress> }
+ } else {
+ html! {}
+ }
+ }
+ <blockquote>
+ <p> { self.generated.chars().map(|c|
+ if c == '\r' || c == '\n' {
+ html! { <br/> }
+ } else {
+ html! { {c} }
+ }).collect::<Html>()
+ } </p>
+ </blockquote>
+ </div>
+ }
+ }
+}
diff --git a/candle-wasm-examples/llama2-c/src/bin/app.rs b/candle-wasm-examples/llama2-c/src/bin/app.rs
new file mode 100644
index 00000000..3428f6ff
--- /dev/null
+++ b/candle-wasm-examples/llama2-c/src/bin/app.rs
@@ -0,0 +1,4 @@
+fn main() {
+ wasm_logger::init(wasm_logger::Config::new(log::Level::Trace));
+ yew::Renderer::<candle_wasm_example_llama2::App>::new().render();
+}
diff --git a/candle-wasm-examples/llama2-c/src/bin/worker.rs b/candle-wasm-examples/llama2-c/src/bin/worker.rs
new file mode 100644
index 00000000..d8ca2172
--- /dev/null
+++ b/candle-wasm-examples/llama2-c/src/bin/worker.rs
@@ -0,0 +1,4 @@
+use yew_agent::PublicWorker;
+fn main() {
+ candle_wasm_example_llama2::Worker::register();
+}
diff --git a/candle-wasm-examples/llama2-c/src/lib.rs b/candle-wasm-examples/llama2-c/src/lib.rs
new file mode 100644
index 00000000..61154d04
--- /dev/null
+++ b/candle-wasm-examples/llama2-c/src/lib.rs
@@ -0,0 +1,30 @@
+#![allow(dead_code)]
+
+pub const WITH_TIMER: bool = true;
+
+struct Timer {
+ label: &'static str,
+}
+
+impl Timer {
+ fn new(label: &'static str) -> Self {
+ if WITH_TIMER {
+ web_sys::console::time_with_label(label);
+ }
+ Self { label }
+ }
+}
+
+impl Drop for Timer {
+ fn drop(&mut self) {
+ if WITH_TIMER {
+ web_sys::console::time_end_with_label(self.label)
+ }
+ }
+}
+
+mod app;
+mod model;
+mod worker;
+pub use app::App;
+pub use worker::Worker;
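The Timer type above is a small RAII guard over the browser's console.time / console.timeEnd pair: constructing it starts a console timer under the given label, and dropping it stops that timer, so the elapsed time of any scope can be logged by binding a Timer at its top. A minimal usage sketch (the function name is hypothetical):

    fn forward_pass() {
        // Logs `forward_pass: <elapsed> ms` to the browser console when `_timer` drops.
        let _timer = Timer::new("forward_pass");
        // ... work to be measured ...
    }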
diff --git a/candle-wasm-examples/llama2-c/src/model.rs b/candle-wasm-examples/llama2-c/src/model.rs
new file mode 100644
index 00000000..13f939db
--- /dev/null
+++ b/candle-wasm-examples/llama2-c/src/model.rs
@@ -0,0 +1,321 @@
+use candle::{DType, Device, IndexOp, Result, Tensor, D};
+use candle_nn::{Embedding, Linear, VarBuilder};
+use std::collections::HashMap;
+use std::sync::{Arc, Mutex};
+
+#[derive(Debug, Clone)]
+pub struct Config {
+ pub dim: usize, // transformer dimension
+ pub hidden_dim: usize, // for ffn layers
+ pub n_layers: usize, // number of layers
+ pub n_heads: usize, // number of query heads
+ pub n_kv_heads: usize, // number of key/value heads (can be < query heads because of multiquery)
+ pub vocab_size: usize, // vocabulary size, usually 256 (byte-level)
+ pub seq_len: usize, // max sequence length
+ pub norm_eps: f64,
+}
+
+#[derive(Clone)]
+pub struct Cache {
+ masks: Arc<Mutex<HashMap<usize, Tensor>>>,
+ pub use_kv_cache: bool,
+ #[allow(clippy::type_complexity)]
+ kvs: Arc<Mutex<Vec<Option<(Tensor, Tensor)>>>>,
+ cos: Tensor,
+ sin: Tensor,
+ device: Device,
+}
+
+impl Cache {
+ pub fn new(use_kv_cache: bool, cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let freq_cis_real = vb.get((cfg.seq_len, cfg.head_size() / 2), "freq_cis_real")?;
+ let freq_cis_imag = vb.get((cfg.seq_len, cfg.head_size() / 2), "freq_cis_imag")?;
+ let cos = freq_cis_real.reshape((cfg.seq_len, cfg.head_size() / 2, 1))?;
+ let sin = freq_cis_imag.reshape((cfg.seq_len, cfg.head_size() / 2, 1))?;
+ Ok(Self {
+ masks: Arc::new(Mutex::new(HashMap::new())),
+ use_kv_cache,
+ kvs: Arc::new(Mutex::new(vec![None; cfg.n_layers])),
+ cos,
+ sin,
+ device: vb.device().clone(),
+ })
+ }
+
+ fn mask(&self, t: usize) -> Result<Tensor> {
+ let mut masks = self.masks.lock().unwrap();
+ if let Some(mask) = masks.get(&t) {
+ Ok(mask.clone())
+ } else {
+ // TODO: If we support bool or u8 tensors, this would be better.
+ let mask: Vec<_> = (0..t)
+ .flat_map(|i| (0..t).map(move |j| u32::from(j > i)))
+ .collect();
+ let mask = Tensor::from_slice(&mask, (t, t), &self.device)?;
+ masks.insert(t, mask.clone());
+ Ok(mask)
+ }
+ }
+}
+
+fn silu(xs: &Tensor) -> Result<Tensor> {
+ xs / (xs.neg()?.exp()? + 1.0)?
+}
+
+fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
+ let weight = vb.get((size2, size1), "weight")?;
+ Ok(Linear::new(weight, None))
+}
+
+fn embedding(cfg: &Config, vb: VarBuilder) -> Result<Embedding> {
+ let embeddings = vb.get((cfg.vocab_size, cfg.dim), "weight")?;
+ Ok(Embedding::new(embeddings, cfg.dim))
+}
+
+struct RmsNorm {
+ scale: Tensor,
+ eps: f64,
+}
+
+impl RmsNorm {
+ fn load(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
+ let scale = vb.get(size, "weight")?;
+ Ok(Self { scale, eps })
+ }
+
+ fn forward(&self, x: &Tensor) -> Result<Tensor> {
+ let (b_sz, seq_len, hidden_size) = x.dims3()?;
+ let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
+ let norm_x = norm_x.broadcast_as((b_sz, seq_len, hidden_size))?;
+ let x_normed = (x / (norm_x + self.eps)?.sqrt()?)?;
+ let size = self.scale.dims1()?;
+ let scale = self
+ .scale
+ .to_dtype(DType::F32)?
+ .broadcast_as((b_sz, seq_len, size))?;
+ let x = (scale * x_normed)?;
+ Ok(x)
+ }
+}
+
+struct CausalSelfAttention {
+ q_proj: Linear,
+ k_proj: Linear,
+ v_proj: Linear,
+ o_proj: Linear,
+ n_head: usize,
+ n_key_value_head: usize,
+ head_dim: usize,
+ cache: Cache,
+ max_seq_len: usize,
+}
+
+impl CausalSelfAttention {
+ fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
+ let (b_sz, seq_len, h, n_embd) = x.dims4()?;
+ let cos = self.cache.cos.narrow(0, index_pos, seq_len)?;
+ let sin = self.cache.sin.narrow(0, index_pos, seq_len)?;
+ let cos = cos.broadcast_as((b_sz, seq_len, 1, n_embd / 2, 1))?;
+ let sin = sin.broadcast_as((b_sz, seq_len, 1, n_embd / 2, 1))?;
+ let x = x.reshape((b_sz, seq_len, h, n_embd / 2, 2))?;
+ let x0 = x.narrow(D::Minus1, 0, 1)?;
+ let x1 = x.narrow(D::Minus1, 1, 1)?;
+ let dst0 = (x0.broadcast_mul(&cos)? - x1.broadcast_mul(&sin)?)?;
+ let dst1 = (x0.broadcast_mul(&sin)? + x1.broadcast_mul(&cos)?)?;
+ let rope = Tensor::cat(&[&dst0, &dst1], D::Minus1)?.reshape((b_sz, seq_len, h, n_embd))?;
+ Ok(rope)
+ }
+
+ fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
+ let (b_sz, seq_len, n_embd) = x.dims3()?;
+ let q = self.q_proj.forward(x)?;
+ let k = self.k_proj.forward(x)?;
+ let v = self.v_proj.forward(x)?;
+
+ let q = q.reshape((b_sz, seq_len, self.n_head, self.head_dim))?;
+ let k = k.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?;
+ let mut v = v.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?;
+
+ let q = self.apply_rotary_emb(&q, index_pos)?;
+ let mut k = self.apply_rotary_emb(&k, index_pos)?;
+
+ if self.cache.use_kv_cache {
+ let mut cache = self.cache.kvs.lock().unwrap();
+ if let Some((cache_k, cache_v)) = &cache[block_idx] {
+ k = Tensor::cat(&[cache_k, &k], 1)?.contiguous()?;
+ v = Tensor::cat(&[cache_v, &v], 1)?.contiguous()?;
+ }
+ cache[block_idx] = Some((k.clone(), v.clone()))
+ }
+
+ let k = self.repeat_kv(k)?;
+ let v = self.repeat_kv(v)?;
+
+ let q = q.transpose(1, 2)?.contiguous()?;
+ let k = k.transpose(1, 2)?.contiguous()?;
+ let v = v.transpose(1, 2)?.contiguous()?;
+
+ let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
+ let mask = self.cache.mask(seq_len)?.broadcast_as(att.shape())?;
+ let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
+ let att = att.softmax(D::Minus1)?;
+        // Convert `v` to contiguous as matmul doesn't support strided inputs for now.
+ let y = att.matmul(&v.contiguous()?)?;
+ let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
+ let y = self.o_proj.forward(&y)?;
+ Ok(y)
+ }
+
+ fn repeat_kv(&self, x: Tensor) -> Result<Tensor> {
+ let n_rep = self.n_head / self.n_key_value_head;
+ if n_rep == 1 {
+ Ok(x)
+ } else {
+ let (b_sz, seq_len, n_kv_head, head_dim) = x.dims4()?;
+ let x = x
+ .unsqueeze(3)?
+ .expand((b_sz, seq_len, n_kv_head, n_rep, head_dim))?
+ .reshape((b_sz, seq_len, n_kv_head * n_rep, head_dim))?;
+ Ok(x)
+ }
+ }
+
+ fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
+ let size_in = cfg.dim;
+ let size_q = (cfg.dim / cfg.n_heads) * cfg.n_heads;
+ let size_kv = (cfg.dim / cfg.n_heads) * cfg.n_kv_heads;
+ let q_proj = linear(size_in, size_q, vb.pp("q_proj"))?;
+ let k_proj = linear(size_in, size_kv, vb.pp("k_proj"))?;
+ let v_proj = linear(size_in, size_kv, vb.pp("v_proj"))?;
+ let o_proj = linear(size_q, size_in, vb.pp("o_proj"))?;
+ Ok(Self {
+ q_proj,
+ k_proj,
+ v_proj,
+ o_proj,
+ n_head: cfg.n_heads,
+ n_key_value_head: cfg.n_kv_heads,
+ head_dim: cfg.dim / cfg.n_heads,
+ cache: cache.clone(),
+ max_seq_len: cfg.seq_len,
+ })
+ }
+}
+
+fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
+ let shape = mask.shape();
+ let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
+ let m = mask.where_cond(&on_true, on_false)?;
+ Ok(m)
+}
+
+struct Mlp {
+ c_fc1: Linear,
+ c_fc2: Linear,
+ c_proj: Linear,
+}
+
+impl Mlp {
+ fn new(c_fc1: Linear, c_fc2: Linear, c_proj: Linear) -> Self {
+ Self {
+ c_fc1,
+ c_fc2,
+ c_proj,
+ }
+ }
+
+ fn forward(&self, x: &Tensor) -> Result<Tensor> {
+ let x = (silu(&self.c_fc1.forward(x)?)? * self.c_fc2.forward(x)?)?;
+ self.c_proj.forward(&x)
+ }
+
+ fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
+ let h_size = cfg.dim;
+ let i_size = cfg.hidden_dim;
+ let c_fc1 = linear(h_size, i_size, vb.pp("gate_proj"))?;
+ let c_fc2 = linear(h_size, i_size, vb.pp("up_proj"))?;
+ let c_proj = linear(i_size, h_size, vb.pp("down_proj"))?;
+ Ok(Self::new(c_fc1, c_fc2, c_proj))
+ }
+}
+
+struct Block {
+ rms_1: RmsNorm,
+ attn: CausalSelfAttention,
+ rms_2: RmsNorm,
+ mlp: Mlp,
+}
+
+impl Block {
+ fn new(rms_1: RmsNorm, attn: CausalSelfAttention, rms_2: RmsNorm, mlp: Mlp) -> Self {
+ Self {
+ rms_1,
+ attn,
+ rms_2,
+ mlp,
+ }
+ }
+
+ fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
+ let residual = x;
+ let x = self.rms_1.forward(x)?;
+ let x = (self.attn.forward(&x, index_pos, block_idx)? + residual)?;
+ let residual = &x;
+ let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? + residual)?;
+ Ok(x)
+ }
+
+ fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
+ let attn = CausalSelfAttention::load(vb.pp("self_attn"), cache, cfg)?;
+ let mlp = Mlp::load(vb.pp("mlp"), cfg)?;
+ let input_layernorm = RmsNorm::load(cfg.dim, cfg.norm_eps, vb.pp("input_layernorm"))?;
+ let post_attention_layernorm =
+ RmsNorm::load(cfg.dim, cfg.norm_eps, vb.pp("post_attention_layernorm"))?;
+ Ok(Self::new(
+ input_layernorm,
+ attn,
+ post_attention_layernorm,
+ mlp,
+ ))
+ }
+}
+
+pub struct Llama {
+ wte: Embedding,
+ blocks: Vec<Block>,
+ ln_f: RmsNorm,
+ lm_head: Linear,
+}
+
+impl Llama {
+ fn new(wte: Embedding, blocks: Vec<Block>, ln_f: RmsNorm, lm_head: Linear) -> Self {
+ Self {
+ wte,
+ blocks,
+ ln_f,
+ lm_head,
+ }
+ }
+
+ pub fn forward(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
+ let (_b_sz, seq_len) = x.dims2()?;
+ let mut x = self.wte.forward(x)?;
+ for (block_idx, block) in self.blocks.iter().enumerate() {
+ x = block.forward(&x, index_pos, block_idx)?;
+ }
+ let x = self.ln_f.forward(&x)?;
+ let x = x.i((.., seq_len - 1, ..))?;
+ let logits = self.lm_head.forward(&x)?;
+ logits.to_dtype(DType::F32)
+ }
+
+ pub fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
+ let wte = embedding(cfg, vb.pp("model.embed_tokens"))?;
+ let lm_head = linear(cfg.dim, cfg.vocab_size, vb.pp("lm_head"))?;
+ let norm = RmsNorm::load(cfg.dim, cfg.norm_eps, vb.pp("model.norm"))?;
+ let blocks: Vec<_> = (0..cfg.n_layers)
+ .map(|i| Block::load(vb.pp(&format!("model.layers.{i}")), cache, cfg).unwrap())
+ .collect();
+ Ok(Self::new(wte, blocks, norm, lm_head))
+ }
+}
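RmsNorm and silu above are the two scalar formulas the block structure leans on: RMS normalization divides each hidden vector by the root of its mean square (plus eps) and rescales with a learned weight, and silu(x) = x·sigmoid(x), written here as x / (1 + e^(-x)). The same arithmetic on plain slices, as an illustrative sketch rather than anything from the diff:

    // rms_norm(x)_i = scale_i * x_i / sqrt(mean(x^2) + eps)
    fn rms_norm(x: &[f32], scale: &[f32], eps: f32) -> Vec<f32> {
        let mean_sq = x.iter().map(|v| v * v).sum::<f32>() / x.len() as f32;
        let inv = 1.0 / (mean_sq + eps).sqrt();
        x.iter().zip(scale).map(|(v, s)| s * v * inv).collect()
    }

    // silu(x) = x * sigmoid(x) = x / (1 + e^-x), matching the tensor version above.
    fn silu(x: f32) -> f32 {
        x / (1.0 + (-x).exp())
    }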
diff --git a/candle-wasm-examples/llama2-c/src/worker.rs b/candle-wasm-examples/llama2-c/src/worker.rs
new file mode 100644
index 00000000..9b0351d6
--- /dev/null
+++ b/candle-wasm-examples/llama2-c/src/worker.rs
@@ -0,0 +1,353 @@
+use crate::model::{Cache, Config, Llama};
+use byteorder::{LittleEndian, ReadBytesExt};
+use candle::{DType, Device, IndexOp, Result, Shape, Tensor, D};
+use candle_nn::VarBuilder;
+use rand::{distributions::Distribution, SeedableRng};
+use serde::{Deserialize, Serialize};
+use wasm_bindgen::prelude::*;
+use yew_agent::{HandlerId, Public, WorkerLink};
+
+#[wasm_bindgen]
+extern "C" {
+ // Use `js_namespace` here to bind `console.log(..)` instead of just
+ // `log(..)`
+ #[wasm_bindgen(js_namespace = console)]
+ pub fn log(s: &str);
+}
+
+#[macro_export]
+macro_rules! console_log {
+    // Note that this uses the console.log binding imported above.
+ ($($t:tt)*) => ($crate::worker::log(&format_args!($($t)*).to_string()))
+}
+
+// Communication with the worker happens through bincode; the model weights and config are
+// fetched on the main thread and transferred via the following structure.
+#[derive(Serialize, Deserialize)]
+pub struct ModelData {
+ pub tokenizer: Vec<u8>,
+ pub model: Vec<u8>,
+}
+
+fn read_i32<R: std::io::Read>(r: &mut R) -> Result<i32> {
+ let mut buf = [0u8; 4];
+ r.read_exact(&mut buf)?;
+ Ok(i32::from_le_bytes(buf))
+}
+
+fn read_tensor<R: std::io::Read, S: Into<Shape>>(
+ r: &mut R,
+ shape: S,
+ dev: &Device,
+) -> Result<Tensor> {
+ let shape = shape.into();
+ let mut data_t = vec![0f32; shape.elem_count()];
+ r.read_f32_into::<LittleEndian>(&mut data_t)?;
+ let tensor = Tensor::from_vec(data_t, shape, dev)?;
+ Ok(tensor)
+}
+
+struct Tokenizer {
+ tokens: Vec<String>,
+}
+
+impl Tokenizer {
+ fn from_reader<R: std::io::Read>(r: &mut R, c: &Config) -> Result<Self> {
+ let mut tokens = Vec::with_capacity(c.vocab_size);
+ for _token_index in 0..c.vocab_size {
+ let token_len = read_i32(r)?;
+ let mut token = vec![0u8; token_len as usize];
+ r.read_exact(&mut token)?;
+ tokens.push(String::from_utf8_lossy(&token).into_owned())
+ }
+ Ok(Self { tokens })
+ }
+}
+
+struct Model {
+ cache: Cache,
+ config: Config,
+ llama: Llama,
+ tokenizer: Tokenizer,
+}
+
+pub struct LogitsProcessor {
+ rng: rand::rngs::StdRng,
+ temperature: Option<f64>,
+}
+
+impl LogitsProcessor {
+ pub fn new(seed: u64, temperature: Option<f64>) -> Self {
+ Self {
+ rng: rand::rngs::StdRng::seed_from_u64(seed),
+ temperature,
+ }
+ }
+
+ pub fn sample(&mut self, logits: &Tensor) -> Result<u32> {
+ let logits = logits.to_dtype(DType::F32)?;
+ let next_token = if let Some(temperature) = self.temperature {
+ let prs = (&logits / temperature)?.softmax(D::Minus1)?;
+ let prs: Vec<f32> = prs.to_vec1()?;
+ let distr =
+ rand::distributions::WeightedIndex::new(prs).map_err(candle::Error::wrap)?;
+ distr.sample(&mut self.rng) as u32
+ } else {
+ let logits_v: Vec<f32> = logits.to_vec1()?;
+ logits_v
+ .iter()
+ .enumerate()
+ .max_by(|(_, u), (_, v)| u.total_cmp(v))
+ .map(|(i, _)| i as u32)
+ .unwrap()
+ };
+ Ok(next_token)
+ }
+}
+
+impl Model {
+ fn run(&self, link: &WorkerLink<Worker>, id: HandlerId) -> Result<()> {
+ let dev = Device::Cpu;
+ let mut logits_processor = LogitsProcessor::new(299792458, None);
+ let mut index_pos = 0;
+ let mut tokens = vec![1u32];
+
+ for index in 0..self.config.seq_len - 10 {
+ let context_size = if self.cache.use_kv_cache && index > 0 {
+ 1
+ } else {
+ tokens.len()
+ };
+ let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
+ let input = Tensor::new(ctxt, &dev)?.unsqueeze(0)?;
+ let logits = self.llama.forward(&input, index_pos)?;
+ let logits = logits.squeeze(0)?;
+ index_pos += ctxt.len();
+
+ let next_token = logits_processor.sample(&logits)?;
+ tokens.push(next_token);
+ let token = self.tokenizer.tokens[next_token as usize].clone();
+ link.respond(id, Ok(WorkerOutput::Generated(token)));
+ }
+ Ok(())
+ }
+}
+
+impl Config {
+ fn from_reader<R: std::io::Read>(r: &mut R) -> Result<Self> {
+ let dim = read_i32(r)? as usize;
+ let hidden_dim = read_i32(r)? as usize;
+ let n_layers = read_i32(r)? as usize;
+ let n_heads = read_i32(r)? as usize;
+ let n_kv_heads = read_i32(r)? as usize;
+ let vocab_size = read_i32(r)? as usize;
+ let seq_len = read_i32(r)? as usize;
+ Ok(Self {
+ dim,
+ hidden_dim,
+ n_layers,
+ n_heads,
+ n_kv_heads,
+ vocab_size,
+ seq_len,
+ norm_eps: 1e-5,
+ })
+ }
+
+ pub fn head_size(&self) -> usize {
+ self.dim / self.n_heads
+ }
+}
+
+struct TransformerWeights {
+ // token embedding table
+ token_embedding_table: Tensor, // (vocab_size, dim)
+ // weights for rmsnorms
+ rms_att_weight: Tensor, // (layer, dim) rmsnorm weights
+ rms_ffn_weight: Tensor, // (layer, dim)
+ // weights for matmuls
+ wq: Tensor, // (layer, dim, dim)
+ wk: Tensor, // (layer, dim, dim)
+ wv: Tensor, // (layer, dim, dim)
+ wo: Tensor, // (layer, dim, dim)
+ // weights for ffn
+ w1: Tensor, // (layer, hidden_dim, dim)
+ w2: Tensor, // (layer, dim, hidden_dim)
+ w3: Tensor, // (layer, hidden_dim, dim)
+ // final rmsnorm
+ rms_final_weight: Tensor, // (dim,)
+    // freq_cis for RoPE relative positional embeddings
+ freq_cis_real: Tensor, // (seq_len, head_size/2)
+ freq_cis_imag: Tensor, // (seq_len, head_size/2)
+}
+
+impl TransformerWeights {
+ fn from_reader<R: std::io::Read>(r: &mut R, c: &Config, dev: &Device) -> Result<Self> {
+ let token_embedding_table = read_tensor(r, (c.vocab_size, c.dim), dev)?;
+ let rms_att_weight = read_tensor(r, (c.n_layers, c.dim), dev)?;
+ let wq = read_tensor(r, (c.n_layers, c.dim, c.dim), dev)?;
+ let wk = read_tensor(r, (c.n_layers, c.dim, c.dim), dev)?;
+ let wv = read_tensor(r, (c.n_layers, c.dim, c.dim), dev)?;
+ let wo = read_tensor(r, (c.n_layers, c.dim, c.dim), dev)?;
+ let rms_ffn_weight = read_tensor(r, (c.n_layers, c.dim), dev)?;
+ let w1 = read_tensor(r, (c.n_layers, c.hidden_dim, c.dim), dev)?;
+ let w2 = read_tensor(r, (c.n_layers, c.dim, c.hidden_dim), dev)?;
+ let w3 = read_tensor(r, (c.n_layers, c.hidden_dim, c.dim), dev)?;
+ let rms_final_weight = read_tensor(r, c.dim, dev)?;
+ let head_size = c.head_size();
+ let freq_cis_real = read_tensor(r, (c.seq_len, head_size / 2), dev)?;
+ let freq_cis_imag = read_tensor(r, (c.seq_len, head_size / 2), dev)?;
+ Ok(Self {
+ token_embedding_table,
+ rms_att_weight,
+ wq,
+ wk,
+ wv,
+ wo,
+ rms_ffn_weight,
+ w1,
+ w2,
+ w3,
+ rms_final_weight,
+ freq_cis_real,
+ freq_cis_imag,
+ })
+ }
+
+ fn var_builder(&self, cfg: &Config, device: &Device) -> Result<VarBuilder> {
+ let mut ws = std::collections::HashMap::new();
+ let mut insert = |name: &str, t: Tensor| {
+ ws.insert(name.to_string(), t);
+ };
+ insert("rot.freq_cis_real", self.freq_cis_real.clone());
+ insert("rot.freq_cis_imag", self.freq_cis_imag.clone());
+ insert(
+ "model.embed_tokens.weight",
+ self.token_embedding_table.clone(),
+ );
+ insert("lm_head.weight", self.token_embedding_table.clone());
+ insert("model.norm.weight", self.rms_final_weight.clone());
+ for layer in 0..cfg.n_layers {
+ ws.insert(
+ format!("model.layers.{layer}.self_attn.q_proj.weight"),
+ self.wq.i(layer)?,
+ );
+ ws.insert(
+ format!("model.layers.{layer}.self_attn.k_proj.weight"),
+ self.wk.i(layer)?,
+ );
+ ws.insert(
+ format!("model.layers.{layer}.self_attn.v_proj.weight"),
+ self.wv.i(layer)?,
+ );
+ ws.insert(
+ format!("model.layers.{layer}.self_attn.o_proj.weight"),
+ self.wo.i(layer)?,
+ );
+ ws.insert(
+ format!("model.layers.{layer}.mlp.gate_proj.weight"),
+ self.w1.i(layer)?,
+ );
+ ws.insert(
+ format!("model.layers.{layer}.mlp.down_proj.weight"),
+ self.w2.i(layer)?,
+ );
+ ws.insert(
+ format!("model.layers.{layer}.mlp.up_proj.weight"),
+ self.w3.i(layer)?,
+ );
+ ws.insert(
+ format!("model.layers.{layer}.input_layernorm.weight"),
+ self.rms_att_weight.i(layer)?,
+ );
+ ws.insert(
+ format!("model.layers.{layer}.post_attention_layernorm.weight"),
+ self.rms_ffn_weight.i(layer)?,
+ );
+ }
+ let vb = VarBuilder::from_tensors(ws, DType::F32, device);
+ Ok(vb)
+ }
+}
+
+impl Model {
+ fn load(md: ModelData) -> Result<Self> {
+ let dev = Device::Cpu;
+ let mut model = std::io::Cursor::new(md.model);
+ let config = Config::from_reader(&mut model)?;
+ let weights = TransformerWeights::from_reader(&mut model, &config, &dev)?;
+ let vb = weights.var_builder(&config, &dev)?;
+ let cache = Cache::new(true, &config, vb.pp("rot"))?;
+ let llama = Llama::load(vb, &cache, &config)?;
+ let mut tokenizer = std::io::Cursor::new(md.tokenizer);
+ let tokenizer = Tokenizer::from_reader(&mut tokenizer, &config)?;
+ Ok(Self {
+ cache,
+ config,
+ llama,
+ tokenizer,
+ })
+ }
+}
+
+pub struct Worker {
+ link: WorkerLink<Self>,
+ model: Option<Model>,
+}
+
+#[derive(Serialize, Deserialize)]
+pub enum WorkerInput {
+ ModelData(ModelData),
+ Run,
+}
+
+#[derive(Serialize, Deserialize)]
+pub enum WorkerOutput {
+ Generated(String),
+ GenerationDone(std::result::Result<(), String>),
+ WeightsLoaded,
+}
+
+impl yew_agent::Worker for Worker {
+ type Input = WorkerInput;
+ type Message = ();
+ type Output = std::result::Result<WorkerOutput, String>;
+ type Reach = Public<Self>;
+
+ fn create(link: WorkerLink<Self>) -> Self {
+ Self { link, model: None }
+ }
+
+ fn update(&mut self, _msg: Self::Message) {
+ // no messaging
+ }
+
+ fn handle_input(&mut self, msg: Self::Input, id: HandlerId) {
+ let output = match msg {
+ WorkerInput::ModelData(md) => match Model::load(md) {
+ Ok(model) => {
+ self.model = Some(model);
+ Ok(WorkerOutput::WeightsLoaded)
+ }
+ Err(err) => Err(format!("model creation error {err:?}")),
+ },
+ WorkerInput::Run => match &self.model {
+ None => Err("model has not been set yet".to_string()),
+ Some(model) => {
+ let result = model.run(&self.link, id).map_err(|e| e.to_string());
+ Ok(WorkerOutput::GenerationDone(result))
+ }
+ },
+ };
+ self.link.respond(id, output);
+ }
+
+ fn name_of_resource() -> &'static str {
+ "worker.js"
+ }
+
+ fn resource_path_is_relative() -> bool {
+ true
+ }
+}
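Config::from_reader and TransformerWeights::from_reader together define the llama2.c checkpoint layout this worker consumes: a header of seven little-endian i32 fields (dim, hidden_dim, n_layers, n_heads, n_kv_heads, vocab_size, seq_len) followed by raw f32 tensors in the exact order read above. A standalone sketch of the header parse, assuming the whole file is already in a byte slice (illustrative helper, not part of the commit):

    fn parse_llama2c_header(bytes: &[u8]) -> Option<[i32; 7]> {
        if bytes.len() < 7 * 4 {
            return None;
        }
        let mut fields = [0i32; 7];
        for (field, chunk) in fields.iter_mut().zip(bytes.chunks_exact(4)) {
            // Each header field is a 4-byte little-endian integer.
            *field = i32::from_le_bytes(chunk.try_into().ok()?);
        }
        Some(fields)
    }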
diff --git a/candle-wasm-examples/whisper/Cargo.toml b/candle-wasm-examples/whisper/Cargo.toml
new file mode 100644
index 00000000..55b507db
--- /dev/null
+++ b/candle-wasm-examples/whisper/Cargo.toml
@@ -0,0 +1,52 @@
+[package]
+name = "candle-wasm-example-whisper"
+version = "0.1.0"
+edition = "2021"
+
+description = "Wasm example for the candle ML framework."
+repository = "https://github.com/LaurentMazare/candle"
+keywords = ["blas", "tensor", "machine-learning"]
+categories = ["science"]
+license = "MIT/Apache-2.0"
+readme = "README.md"
+
+[dependencies]
+candle = { path = "../../candle-core" }
+candle-nn = { path = "../../candle-nn" }
+num-traits = { workspace = true }
+tokenizers = { workspace = true, features = ["unstable_wasm"] }
+
+# App crates.
+anyhow = { workspace = true }
+log = { workspace = true }
+rand = { workspace = true }
+serde = { workspace = true }
+serde_json = { workspace = true }
+wav = { workspace = true }
+
+# Wasm specific crates.
+getrandom = { version = "0.2", features = ["js"] }
+gloo = "0.8"
+js-sys = "0.3.64"
+wasm-bindgen = "0.2.87"
+wasm-bindgen-futures = "0.4.37"
+wasm-logger = "0.2"
+yew-agent = "0.2.0"
+yew = { version = "0.20.0", features = ["csr"] }
+
+[dependencies.web-sys]
+version = "0.3.64"
+features = [
+ 'Blob',
+ 'Document',
+ 'Element',
+ 'HtmlElement',
+ 'Node',
+ 'Window',
+ 'Request',
+ 'RequestCache',
+ 'RequestInit',
+ 'RequestMode',
+ 'Response',
+ 'Performance',
+]
diff --git a/candle-wasm-examples/whisper/index.html b/candle-wasm-examples/whisper/index.html
new file mode 100644
index 00000000..7a21c4f2
--- /dev/null
+++ b/candle-wasm-examples/whisper/index.html
@@ -0,0 +1,24 @@
+<!DOCTYPE html>
+<html lang="en">
+ <head>
+ <meta charset="utf-8" />
+ <title>Welcome to Candle!</title>
+
+ <link data-trunk rel="copy-file" href="jfk.wav" />
+ <link data-trunk rel="copy-file" href="mm0.wav" />
+ <link data-trunk rel="copy-file" href="a13.wav" />
+ <link data-trunk rel="copy-file" href="gb0.wav" />
+ <link data-trunk rel="copy-file" href="gb1.wav" />
+ <link data-trunk rel="copy-file" href="hp0.wav" />
+ <link data-trunk rel="copy-file" href="tokenizer.en.json" />
+ <link data-trunk rel="copy-file" href="mel_filters.safetensors" />
+ <link data-trunk rel="copy-file" href="tiny.en.safetensors" />
+ <link data-trunk rel="rust" href="Cargo.toml" data-bin="app" data-type="main" />
+ <link data-trunk rel="rust" href="Cargo.toml" data-bin="worker" data-type="worker" />
+
+ <link rel="stylesheet" href="https://fonts.googleapis.com/css?family=Roboto:300,300italic,700,700italic">
+ <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/normalize/8.0.1/normalize.css">
+ <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/milligram/1.4.1/milligram.css">
+ </head>
+ <body></body>
+</html>
diff --git a/candle-wasm-examples/whisper/main.js b/candle-wasm-examples/whisper/main.js
new file mode 100644
index 00000000..c27e0d60
--- /dev/null
+++ b/candle-wasm-examples/whisper/main.js
@@ -0,0 +1,6 @@
+import init, { run_app } from './pkg/candle_wasm_example_whisper.js';
+async function main() {
+ await init('/pkg/candle_wasm_example_whisper_bg.wasm');
+ run_app();
+}
+main()
diff --git a/candle-wasm-examples/whisper/src/app.rs b/candle-wasm-examples/whisper/src/app.rs
new file mode 100644
index 00000000..23519ebd
--- /dev/null
+++ b/candle-wasm-examples/whisper/src/app.rs
@@ -0,0 +1,238 @@
+use crate::console_log;
+use crate::worker::{ModelData, Segment, Worker, WorkerInput, WorkerOutput};
+use js_sys::Date;
+use wasm_bindgen::prelude::*;
+use wasm_bindgen_futures::JsFuture;
+use yew::{html, Component, Context, Html};
+use yew_agent::{Bridge, Bridged};
+
+const SAMPLE_NAMES: [&str; 6] = [
+ "jfk.wav", "a13.wav", "gb0.wav", "gb1.wav", "hp0.wav", "mm0.wav",
+];
+
+async fn fetch_url(url: &str) -> Result<Vec<u8>, JsValue> {
+ use web_sys::{Request, RequestCache, RequestInit, RequestMode, Response};
+ let window = web_sys::window().ok_or("window")?;
+ let mut opts = RequestInit::new();
+ let opts = opts
+ .method("GET")
+ .mode(RequestMode::Cors)
+ .cache(RequestCache::NoCache);
+
+ let request = Request::new_with_str_and_init(url, opts)?;
+
+ let resp_value = JsFuture::from(window.fetch_with_request(&request)).await?;
+
+ // `resp_value` is a `Response` object.
+ assert!(resp_value.is_instance_of::<Response>());
+ let resp: Response = resp_value.dyn_into()?;
+ let data = JsFuture::from(resp.blob()?).await?;
+ let blob = web_sys::Blob::from(data);
+ let array_buffer = JsFuture::from(blob.array_buffer()).await?;
+ let data = js_sys::Uint8Array::new(&array_buffer).to_vec();
+ Ok(data)
+}
+
+pub enum Msg {
+ Run(usize),
+ UpdateStatus(String),
+ SetDecoder(ModelData),
+ WorkerInMsg(WorkerInput),
+ WorkerOutMsg(Result<WorkerOutput, String>),
+}
+
+pub struct CurrentDecode {
+ start_time: Option<f64>,
+}
+
+pub struct App {
+ status: String,
+ segments: Vec<Segment>,
+ current_decode: Option<CurrentDecode>,
+ worker: Box<dyn Bridge<Worker>>,
+}
+
+async fn model_data_load() -> Result<ModelData, JsValue> {
+ let tokenizer = fetch_url("tokenizer.en.json").await?;
+ let mel_filters = fetch_url("mel_filters.safetensors").await?;
+ let weights = fetch_url("tiny.en.safetensors").await?;
+ console_log!("{}", weights.len());
+ Ok(ModelData {
+ tokenizer,
+ mel_filters,
+ weights,
+ })
+}
+
+fn performance_now() -> Option<f64> {
+ let window = web_sys::window()?;
+ let performance = window.performance()?;
+ Some(performance.now() / 1000.)
+}
+
+impl Component for App {
+ type Message = Msg;
+ type Properties = ();
+
+ fn create(ctx: &Context<Self>) -> Self {
+ let status = "loading weights".to_string();
+ let cb = {
+ let link = ctx.link().clone();
+ move |e| link.send_message(Self::Message::WorkerOutMsg(e))
+ };
+ let worker = Worker::bridge(std::rc::Rc::new(cb));
+ Self {
+ status,
+ segments: vec![],
+ current_decode: None,
+ worker,
+ }
+ }
+
+ fn rendered(&mut self, ctx: &Context<Self>, first_render: bool) {
+ if first_render {
+ ctx.link().send_future(async {
+ match model_data_load().await {
+ Err(err) => {
+ let status = format!("{err:?}");
+ Msg::UpdateStatus(status)
+ }
+ Ok(model_data) => Msg::SetDecoder(model_data),
+ }
+ });
+ }
+ }
+
+ fn update(&mut self, ctx: &Context<Self>, msg: Self::Message) -> bool {
+ match msg {
+ Msg::SetDecoder(md) => {
+ self.status = "weights loaded succesfully!".to_string();
+ console_log!("loaded weights");
+ self.worker.send(WorkerInput::ModelData(md));
+ true
+ }
+ Msg::Run(sample_index) => {
+ let sample = SAMPLE_NAMES[sample_index];
+ if self.current_decode.is_some() {
+ self.status = "already decoding some sample at the moment".to_string()
+ } else {
+ let start_time = performance_now();
+ self.current_decode = Some(CurrentDecode { start_time });
+ self.status = format!("decoding {sample}");
+ self.segments.clear();
+ ctx.link().send_future(async move {
+ match fetch_url(sample).await {
+ Err(err) => {
+ let output = Err(format!("decoding error: {err:?}"));
+                                // Mimic a worker output so as to release current_decode.
+ Msg::WorkerOutMsg(output)
+ }
+ Ok(wav_bytes) => {
+ Msg::WorkerInMsg(WorkerInput::DecodeTask { wav_bytes })
+ }
+ }
+ })
+ }
+ true
+ }
+ Msg::WorkerOutMsg(output) => {
+ let dt = self.current_decode.as_ref().and_then(|current_decode| {
+ current_decode.start_time.and_then(|start_time| {
+ performance_now().map(|stop_time| stop_time - start_time)
+ })
+ });
+ self.current_decode = None;
+ match output {
+ Ok(WorkerOutput::WeightsLoaded) => self.status = "weights loaded!".to_string(),
+ Ok(WorkerOutput::Decoded(segments)) => {
+ self.status = match dt {
+ None => "decoding succeeded!".to_string(),
+ Some(dt) => format!("decoding succeeded in {:.2}s", dt),
+ };
+ self.segments = segments;
+ }
+ Err(err) => {
+ self.status = format!("decoding error {err:?}");
+ }
+ }
+ true
+ }
+ Msg::WorkerInMsg(inp) => {
+ self.worker.send(inp);
+ true
+ }
+ Msg::UpdateStatus(status) => {
+ self.status = status;
+ true
+ }
+ }
+ }
+
+ fn view(&self, ctx: &Context<Self>) -> Html {
+ html! {
+ <div>
+ <table>
+ <thead>
+ <tr>
+ <th>{"Sample"}</th>
+ <th></th>
+ <th></th>
+ </tr>
+ </thead>
+ <tbody>
+ {
+ SAMPLE_NAMES.iter().enumerate().map(|(i, name)| { html! {
+ <tr>
+ <th>{name}</th>
+ <th><audio controls=true src={format!("./{name}")}></audio></th>
+ <th><button class="button" onclick={ctx.link().callback(move |_| Msg::Run(i))}> { "run" }</button></th>
+ </tr>
+ }
+ }).collect::<Html>()
+ }
+ </tbody>
+ </table>
+ <h2>
+ {&self.status}
+ </h2>
+ {
+ if self.current_decode.is_some() {
+ html! { <progress id="progress-bar" aria-label="decoding…"></progress> }
+ } else { html!{
+ <blockquote>
+ <p>
+ {
+ self.segments.iter().map(|segment| { html! {
+ <>
+ <i>
+ {
+ format!("{:.2}s-{:.2}s: (avg-logprob: {:.4}, no-speech-prob: {:.4})",
+ segment.start,
+ segment.start + segment.duration,
+ segment.dr.avg_logprob,
+ segment.dr.no_speech_prob,
+ )
+ }
+ </i>
+                                    <br/>
+                                    {&segment.dr.text}
+                                    <br/>
+ </>
+ } }).collect::<Html>()
+ }
+ </p>
+ </blockquote>
+ }
+ }
+ }
+
+ // Display the current date and time the page was rendered
+ <p class="footer">
+ { "Rendered: " }
+ { String::from(Date::new_0().to_string()) }
+ </p>
+ </div>
+ }
+ }
+}
diff --git a/candle-wasm-examples/whisper/src/audio.rs b/candle-wasm-examples/whisper/src/audio.rs
new file mode 100644
index 00000000..5b414368
--- /dev/null
+++ b/candle-wasm-examples/whisper/src/audio.rs
@@ -0,0 +1,217 @@
+// Audio processing code, adapted from whisper.cpp
+// https://github.com/ggerganov/whisper.cpp
+use super::worker;
+
+pub trait Float: num_traits::Float + num_traits::FloatConst + num_traits::NumAssign {}
+
+impl Float for f32 {}
+impl Float for f64 {}
+
+// https://github.com/ggerganov/whisper.cpp/blob/4774d2feb01a772a15de81ffc34b34a1f294f020/whisper.cpp#L2357
+fn fft<T: Float>(inp: &[T]) -> Vec<T> {
+ let n = inp.len();
+ let zero = T::zero();
+ if n == 1 {
+ return vec![inp[0], zero];
+ }
+ if n % 2 == 1 {
+ return dft(inp);
+ }
+ let mut out = vec![zero; n * 2];
+
+    let mut even = Vec::with_capacity(n / 2);
+    let mut odd = Vec::with_capacity(n / 2);
+
+ for (i, &inp) in inp.iter().enumerate() {
+ if i % 2 == 0 {
+ even.push(inp)
+ } else {
+ odd.push(inp);
+ }
+ }
+
+ let even_fft = fft(&even);
+ let odd_fft = fft(&odd);
+
+ let two_pi = T::PI() + T::PI();
+ let n_t = T::from(n).unwrap();
+ for k in 0..n / 2 {
+ let k_t = T::from(k).unwrap();
+ let theta = two_pi * k_t / n_t;
+ let re = theta.cos();
+ let im = -theta.sin();
+
+ let re_odd = odd_fft[2 * k];
+ let im_odd = odd_fft[2 * k + 1];
+
+ out[2 * k] = even_fft[2 * k] + re * re_odd - im * im_odd;
+ out[2 * k + 1] = even_fft[2 * k + 1] + re * im_odd + im * re_odd;
+
+ out[2 * (k + n / 2)] = even_fft[2 * k] - re * re_odd + im * im_odd;
+ out[2 * (k + n / 2) + 1] = even_fft[2 * k + 1] - re * im_odd - im * re_odd;
+ }
+ out
+}
+
+// https://github.com/ggerganov/whisper.cpp/blob/4774d2feb01a772a15de81ffc34b34a1f294f020/whisper.cpp#L2337
+fn dft<T: Float>(inp: &[T]) -> Vec<T> {
+ let zero = T::zero();
+ let n = inp.len();
+ let two_pi = T::PI() + T::PI();
+
+    let mut out = Vec::with_capacity(2 * n);
+ let n_t = T::from(n).unwrap();
+ for k in 0..n {
+ let k_t = T::from(k).unwrap();
+ let mut re = zero;
+ let mut im = zero;
+
+ for (j, &inp) in inp.iter().enumerate() {
+ let j_t = T::from(j).unwrap();
+ let angle = two_pi * k_t * j_t / n_t;
+ re += inp * angle.cos();
+ im -= inp * angle.sin();
+ }
+
+ out.push(re);
+ out.push(im);
+ }
+ out
+}
+
+#[allow(clippy::too_many_arguments)]
+// https://github.com/ggerganov/whisper.cpp/blob/4774d2feb01a772a15de81ffc34b34a1f294f020/whisper.cpp#L2414
+fn log_mel_spectrogram_w<T: Float>(
+ ith: usize,
+ hann: &[T],
+ samples: &[T],
+ filters: &[T],
+ fft_size: usize,
+ fft_step: usize,
+ speed_up: bool,
+ n_len: usize,
+ n_mel: usize,
+ n_threads: usize,
+) -> Vec<T> {
+ let n_fft = if speed_up {
+ 1 + fft_size / 4
+ } else {
+ 1 + fft_size / 2
+ };
+
+ let zero = T::zero();
+ let half = T::from(0.5).unwrap();
+ let mut fft_in = vec![zero; fft_size];
+ let mut mel = vec![zero; n_len * n_mel];
+
+ for i in (ith..n_len).step_by(n_threads) {
+ let offset = i * fft_step;
+
+ // apply Hanning window
+ for j in 0..fft_size {
+ fft_in[j] = if offset + j < samples.len() {
+ hann[j] * samples[offset + j]
+ } else {
+ zero
+ }
+ }
+
+ // FFT -> mag^2
+ let mut fft_out: Vec<T> = fft(&fft_in);
+
+ for j in 0..fft_size {
+ fft_out[j] = fft_out[2 * j] * fft_out[2 * j] + fft_out[2 * j + 1] * fft_out[2 * j + 1];
+ }
+ for j in 1..fft_size / 2 {
+ let v = fft_out[fft_size - j];
+ fft_out[j] += v;
+ }
+
+ if speed_up {
+ // scale down in the frequency domain results in a speed up in the time domain
+ for j in 0..n_fft {
+ fft_out[j] = half * (fft_out[2 * j] + fft_out[2 * j + 1]);
+ }
+ }
+
+ // mel spectrogram
+ for j in 0..n_mel {
+ let mut sum = zero;
+ for k in 0..n_fft {
+ sum += fft_out[k] * filters[j * n_fft + k];
+ }
+ mel[j * n_len + i] = T::max(sum, T::from(1e-10).unwrap()).log10();
+ }
+ }
+ mel
+}
+
+fn log_mel_spectrogram_<T: Float + std::fmt::Display>(
+ samples: &[T],
+ filters: &[T],
+ fft_size: usize,
+ fft_step: usize,
+ n_mel: usize,
+ speed_up: bool,
+) -> Vec<T> {
+ let zero = T::zero();
+ let two_pi = T::PI() + T::PI();
+ let half = T::from(0.5).unwrap();
+ let one = T::from(1.0).unwrap();
+ let four = T::from(4.0).unwrap();
+ let fft_size_t = T::from(fft_size).unwrap();
+
+ let hann: Vec<T> = (0..fft_size)
+ .map(|i| half * (one - ((two_pi * T::from(i).unwrap()) / fft_size_t).cos()))
+ .collect();
+ let n_len = samples.len() / fft_step;
+
+ // pad audio with at least one extra chunk of zeros
+ let pad = 100 * worker::CHUNK_LENGTH / 2;
+ let n_len = if n_len % pad != 0 {
+ (n_len / pad + 1) * pad
+ } else {
+ n_len
+ };
+ let n_len = n_len + pad;
+ let samples = {
+ let mut samples_padded = samples.to_vec();
+ let to_add = n_len * fft_step - samples.len();
+ samples_padded.extend(std::iter::repeat(zero).take(to_add));
+ samples_padded
+ };
+
+ // Use a single thread for now.
+ let mut mel = log_mel_spectrogram_w(
+ 0, &hann, &samples, filters, fft_size, fft_step, speed_up, n_len, n_mel, 1,
+ );
+ let mmax = mel
+ .iter()
+ .max_by(|&u, &v| u.partial_cmp(v).unwrap_or(std::cmp::Ordering::Greater))
+ .copied()
+ .unwrap_or(zero)
+ - T::from(8).unwrap();
+ for m in mel.iter_mut() {
+ let v = T::max(*m, mmax);
+ *m = v / four + one
+ }
+ mel
+}
+
+pub fn pcm_to_mel<T: Float + std::fmt::Display>(
+ samples: &[T],
+ filters: &[T],
+) -> anyhow::Result<Vec<T>> {
+ let mel = log_mel_spectrogram_(
+ samples,
+ filters,
+ worker::N_FFT,
+ worker::HOP_LENGTH,
+ worker::N_MELS,
+ false,
+ );
+ Ok(mel)
+}
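Both fft and dft return the spectrum interleaved as [re0, im0, re1, im1, ...], and fft recurses on the even/odd halves, falling back to the naive dft whenever the length turns odd, so the two paths must agree wherever both apply. A sanity-check sketch, assuming it lives in this module (illustrative, using f64 for headroom):

    #[cfg(test)]
    mod tests {
        use super::*;

        #[test]
        fn fft_matches_dft() {
            let inp: Vec<f64> = (1..=8).map(|v| v as f64).collect();
            // Radix-2 recursion on a power-of-two input vs the O(n^2) reference.
            let (fast, naive) = (fft(&inp), dft(&inp));
            for (a, b) in fast.iter().zip(naive.iter()) {
                assert!((a - b).abs() < 1e-9);
            }
        }
    }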
diff --git a/candle-wasm-examples/whisper/src/bin/app.rs b/candle-wasm-examples/whisper/src/bin/app.rs
new file mode 100644
index 00000000..89efa7f7
--- /dev/null
+++ b/candle-wasm-examples/whisper/src/bin/app.rs
@@ -0,0 +1,4 @@
+fn main() {
+ wasm_logger::init(wasm_logger::Config::new(log::Level::Trace));
+ yew::Renderer::<candle_wasm_example_whisper::App>::new().render();
+}
diff --git a/candle-wasm-examples/whisper/src/bin/worker.rs b/candle-wasm-examples/whisper/src/bin/worker.rs
new file mode 100644
index 00000000..b8c16b56
--- /dev/null
+++ b/candle-wasm-examples/whisper/src/bin/worker.rs
@@ -0,0 +1,4 @@
+use yew_agent::PublicWorker;
+fn main() {
+ candle_wasm_example_whisper::Worker::register();
+}
diff --git a/candle-wasm-examples/whisper/src/lib.rs b/candle-wasm-examples/whisper/src/lib.rs
new file mode 100644
index 00000000..b47d43ca
--- /dev/null
+++ b/candle-wasm-examples/whisper/src/lib.rs
@@ -0,0 +1,31 @@
+#![allow(dead_code)]
+
+pub const WITH_TIMER: bool = true;
+
+struct Timer {
+ label: &'static str,
+}
+
+impl Timer {
+ fn new(label: &'static str) -> Self {
+ if WITH_TIMER {
+ web_sys::console::time_with_label(label);
+ }
+ Self { label }
+ }
+}
+
+impl Drop for Timer {
+ fn drop(&mut self) {
+ if WITH_TIMER {
+ web_sys::console::time_end_with_label(self.label)
+ }
+ }
+}
+
+mod app;
+mod audio;
+mod model;
+mod worker;
+pub use app::App;
+pub use worker::Worker;
diff --git a/candle-wasm-examples/whisper/src/model.rs b/candle-wasm-examples/whisper/src/model.rs
new file mode 100644
index 00000000..97eff839
--- /dev/null
+++ b/candle-wasm-examples/whisper/src/model.rs
@@ -0,0 +1,421 @@
+#![allow(dead_code)]
+// We use anyhow rather than candle errors as it provides better support for getting the backtrace
+// back when using RUST_LIB_BACKTRACE=1.
+use anyhow::Result;
+use candle::{Device, Tensor};
+use candle_nn::{Conv1d, Conv1dConfig, Embedding, LayerNorm, VarBuilder};
+use serde::Deserialize;
+
+// The names in comments correspond to the original implementation:
+// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L17
+#[derive(Debug, Clone, PartialEq, Deserialize)]
+pub struct Config {
+ pub num_mel_bins: usize, // n_mels
+ pub max_source_positions: usize, // n_audio_ctx
+ pub d_model: usize, // n_audio_state
+ pub encoder_attention_heads: usize, // n_audio_head
+ pub encoder_layers: usize, // n_audio_layer
+ pub vocab_size: usize, // n_vocab
+ pub max_target_positions: usize, // n_text_ctx
+ // pub n_text_state: usize,
+ pub decoder_attention_heads: usize, // n_text_head
+ pub decoder_layers: usize, // n_text_layer
+}
+
+impl Config {
+ pub fn tiny_en() -> Self {
+ Self {
+ num_mel_bins: 80,
+ vocab_size: 51864,
+ max_source_positions: 1500,
+ d_model: 384,
+ encoder_attention_heads: 6,
+ encoder_layers: 4,
+ max_target_positions: 448,
+ // n_text_state: 384,
+ decoder_attention_heads: 6,
+ decoder_layers: 4,
+ }
+ }
+}
+
+// The struct below is duplicated from candle_nn::Linear so that it's easier to add some wasm
+// specific monitoring.
+#[derive(Debug)]
+struct Linear {
+ weight: Tensor,
+ bias: Option<Tensor>,
+}
+
+impl Linear {
+ fn new(weight: Tensor, bias: Option<Tensor>) -> Self {
+ Self { weight, bias }
+ }
+
+ fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
+ let _timer = crate::Timer::new("Linear::forward");
+ let w = match x.dims() {
+ &[bsize, _, _] => self.weight.broadcast_left(bsize)?.t()?,
+ _ => self.weight.t()?,
+ };
+ let x = {
+ let _timer = crate::Timer::new("Linear::matmul");
+ x.matmul(&w)?
+ };
+ match &self.bias {
+ None => Ok(x),
+ Some(bias) => x.broadcast_add(bias),
+ }
+ }
+}
+
+fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
+ let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
+ Ok(Embedding::new(embeddings, hidden_size))
+}
+
+fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
+ let weight = vb.get((size2, size1), "weight")?;
+ let bias = vb.get(size2, "bias")?;
+ Ok(Linear::new(weight, Some(bias)))
+}
+
+fn linear_no_bias(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
+ let weight = vb.get((size2, size1), "weight")?;
+ Ok(Linear::new(weight, None))
+}
+
+fn conv1d(
+ in_channels: usize,
+ out_channels: usize,
+ kernel_size: usize,
+ config: Conv1dConfig,
+ vb: VarBuilder,
+) -> Result<Conv1d> {
+ let weight = vb.get((out_channels, in_channels, kernel_size), "weight")?;
+ let bias = vb.get(out_channels, "bias")?;
+ Ok(Conv1d::new(weight, Some(bias), config))
+}
+
+fn conv1d_no_bias(
+ in_channels: usize,
+ out_channels: usize,
+ kernel_size: usize,
+ config: Conv1dConfig,
+ vb: VarBuilder,
+) -> Result<Conv1d> {
+ let weight = vb.get((out_channels, in_channels, kernel_size), "weight")?;
+ Ok(Conv1d::new(weight, None, config))
+}
+
+struct Dropout {
+ pr: f64,
+}
+
+impl Dropout {
+ fn new(pr: f64) -> Self {
+ Self { pr }
+ }
+
+ fn forward(&self, x: &Tensor) -> Result<Tensor> {
+ // TODO
+ Ok(x.clone())
+ }
+}
+
+fn layer_norm(size: usize, vb: VarBuilder) -> Result<LayerNorm> {
+ let weight = vb.get(size, "weight")?;
+ let bias = vb.get(size, "bias")?;
+ Ok(LayerNorm::new(weight, bias, 1e-5))
+}
+
+// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L62
+struct MultiHeadAttention {
+ query: Linear,
+ key: Linear,
+ value: Linear,
+ out: Linear,
+ n_head: usize,
+}
+
+impl MultiHeadAttention {
+ fn load(n_state: usize, n_head: usize, vb: VarBuilder) -> Result<Self> {
+ let query = linear(n_state, n_state, vb.pp("q_proj"))?;
+ let value = linear(n_state, n_state, vb.pp("v_proj"))?;
+ let key = linear_no_bias(n_state, n_state, vb.pp("k_proj"))?;
+ let out = linear(n_state, n_state, vb.pp("out_proj"))?;
+ Ok(Self {
+ query,
+ key,
+ value,
+ out,
+ n_head,
+ })
+ }
+
+ fn forward(&self, x: &Tensor, xa: Option<&Tensor>, mask: Option<&Tensor>) -> Result<Tensor> {
+ let _timer = crate::Timer::new("MultiHeadAttention::forward");
+ let q = self.query.forward(x)?;
+ let k = self.key.forward(xa.unwrap_or(x))?;
+ let v = self.value.forward(xa.unwrap_or(x))?;
+ let wv = self.qkv_attention(&q, &k, &v, mask)?;
+ let out = self.out.forward(&wv)?;
+ Ok(out)
+ }
+
+ fn reshape_head(&self, x: &Tensor) -> Result<Tensor> {
+ let (n_batch, n_ctx, n_state) = x.dims3()?;
+ let target_dims = &[n_batch, n_ctx, self.n_head, n_state / self.n_head];
+ Ok(x.reshape(target_dims)?.transpose(1, 2)?)
+ }
+
+ fn qkv_attention(
+ &self,
+ q: &Tensor,
+ k: &Tensor,
+ v: &Tensor,
+ mask: Option<&Tensor>,
+ ) -> Result<Tensor> {
+ let (_, n_ctx, n_state) = q.dims3()?;
+ let scale = ((n_state / self.n_head) as f64).powf(-0.25);
+ let q = {
+ let _timer = crate::Timer::new("q::reshape");
+ (self.reshape_head(q)? * scale)?
+ };
+ let k = {
+ let _timer = crate::Timer::new("k::reshape");
+ (self.reshape_head(k)?.transpose(2, 3)? * scale)?
+ };
+ let v = {
+ let _timer = crate::Timer::new("v::reshape-contiguous");
+ self.reshape_head(v)?.contiguous()?
+ };
+ let mut qk = {
+ let _timer = crate::Timer::new("qk::matmul");
+ q.matmul(&k)?
+ };
+ if let Some(mask) = mask {
+ let mask = mask.narrow(0, 0, n_ctx)?.narrow(1, 0, n_ctx)?;
+ qk = qk.broadcast_add(&mask)?
+ }
+ let w = {
+ let _timer = crate::Timer::new("qk::softmax");
+ qk.softmax(candle::D::Minus1)?
+ };
+ let wv = {
+ let _timer = crate::Timer::new("wv::matmul");
+ w.matmul(&v)?.transpose(1, 2)?.flatten_from(2)?
+ };
+ Ok(wv)
+ }
+}
+
+// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L111
+struct ResidualAttentionBlock {
+ attn: MultiHeadAttention,
+ attn_ln: LayerNorm,
+ cross_attn: Option<(MultiHeadAttention, LayerNorm)>,
+ mlp_linear1: Linear,
+ mlp_linear2: Linear,
+ mlp_ln: LayerNorm,
+}
+
+impl ResidualAttentionBlock {
+ fn load(n_state: usize, n_head: usize, ca: bool, vb: VarBuilder) -> Result<Self> {
+ let attn = MultiHeadAttention::load(n_state, n_head, vb.pp("self_attn"))?;
+ let attn_ln = layer_norm(n_state, vb.pp("self_attn_layer_norm"))?;
+ let cross_attn = if ca {
+ let cross_attn = MultiHeadAttention::load(n_state, n_head, vb.pp("encoder_attn"))?;
+ let cross_attn_ln = layer_norm(n_state, vb.pp("encoder_attn_layer_norm"))?;
+ Some((cross_attn, cross_attn_ln))
+ } else {
+ None
+ };
+ let n_mlp = n_state * 4;
+ let mlp_linear1 = linear(n_state, n_mlp, vb.pp("fc1"))?;
+ let mlp_linear2 = linear(n_mlp, n_state, vb.pp("fc2"))?;
+ let mlp_ln = layer_norm(n_state, vb.pp("final_layer_norm"))?;
+ Ok(Self {
+ attn,
+ attn_ln,
+ cross_attn,
+ mlp_linear1,
+ mlp_linear2,
+ mlp_ln,
+ })
+ }
+
+ fn forward(&self, x: &Tensor, xa: Option<&Tensor>, mask: Option<&Tensor>) -> Result<Tensor> {
+ let _timer = crate::Timer::new("ResidualAttentionBlock::forward");
+ let attn = self.attn.forward(&self.attn_ln.forward(x)?, None, mask)?;
+ let mut x = (x + attn)?;
+ if let Some((attn, ln)) = &self.cross_attn {
+ x = (&x + attn.forward(&ln.forward(&x)?, xa, None)?)?;
+ }
+ let mlp = self.mlp_linear2.forward(
+ &self
+ .mlp_linear1
+ .forward(&self.mlp_ln.forward(&x)?)?
+ .gelu()?,
+ )?;
+ Ok((x + mlp)?)
+ }
+}
+
+fn sinusoids(length: usize, channels: usize) -> Result<Tensor> {
+ let max_timescale = 10000f32;
+ let log_timescale_increment = max_timescale.ln() / (channels / 2 - 1) as f32;
+ let inv_timescales: Vec<_> = (0..channels / 2)
+ .map(|i| (i as f32 * (-log_timescale_increment)).exp())
+ .collect();
+ let inv_timescales = Tensor::new(inv_timescales.as_slice(), &Device::Cpu)?.unsqueeze(0)?;
+ let arange = Tensor::arange(0, length as u32, &Device::Cpu)?
+ .to_dtype(candle::DType::F32)?
+ .unsqueeze(1)?;
+ let sh = (length, channels / 2);
+ let scaled_time = (arange.broadcast_as(sh)? * inv_timescales.broadcast_as(sh)?)?;
+ let sincos = Tensor::cat(&[scaled_time.sin()?, scaled_time.cos()?], 1)?;
+ Ok(sincos)
+}
+
+// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L143
+pub struct AudioEncoder {
+ conv1: Conv1d,
+ conv2: Conv1d,
+ positional_embedding: Tensor,
+ blocks: Vec<ResidualAttentionBlock>,
+ ln_post: LayerNorm,
+}
+
+impl AudioEncoder {
+ fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
+ let n_state = cfg.d_model;
+ let n_head = cfg.encoder_attention_heads;
+ let n_ctx = cfg.max_source_positions;
+ let cfg1 = Conv1dConfig {
+ padding: 1,
+ stride: 1,
+ };
+ let cfg2 = Conv1dConfig {
+ padding: 1,
+ stride: 2,
+ };
+ let conv1 = conv1d(cfg.num_mel_bins, n_state, 3, cfg1, vb.pp("conv1"))?;
+ let conv2 = conv1d(n_state, n_state, 3, cfg2, vb.pp("conv2"))?;
+ let positional_embedding = sinusoids(n_ctx, n_state)?.to_device(vb.device())?;
+ let blocks = (0..cfg.encoder_layers)
+ .map(|i| {
+ ResidualAttentionBlock::load(n_state, n_head, false, vb.pp(&format!("layers.{i}")))
+ })
+ .collect::<Result<Vec<_>>>()?;
+ let ln_post = layer_norm(n_state, vb.pp("layer_norm"))?;
+ Ok(Self {
+ conv1,
+ conv2,
+ positional_embedding,
+ blocks,
+ ln_post,
+ })
+ }
+ pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
+ let _timer = crate::Timer::new("AudioEncoder::forward");
+ let x = {
+ let _timer = crate::Timer::new("conv1::forward");
+ self.conv1.forward(x)?.gelu()?
+ };
+ let x = {
+ let _timer = crate::Timer::new("conv2::forward");
+ self.conv2.forward(&x)?.gelu()?
+ };
+ let x = x.transpose(1, 2)?;
+ let (_bsize, seq_len, _hidden) = x.dims3()?;
+ let positional_embedding = self.positional_embedding.narrow(0, 0, seq_len)?;
+ let mut x = x.broadcast_add(&positional_embedding)?;
+ for block in self.blocks.iter() {
+ x = block.forward(&x, None, None)?
+ }
+ let x = self.ln_post.forward(&x)?;
+ Ok(x)
+ }
+}
+
+// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L176
+pub struct TextDecoder {
+ token_embedding: Embedding,
+ positional_embedding: Tensor,
+ blocks: Vec<ResidualAttentionBlock>,
+ ln: LayerNorm,
+ mask: Tensor,
+}
+
+impl TextDecoder {
+ fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
+ let _timer = crate::Timer::new("TextDecoder::forward");
+ let n_state = cfg.d_model;
+ let n_head = cfg.decoder_attention_heads;
+ let n_ctx = cfg.max_target_positions;
+ let token_embedding = embedding(cfg.vocab_size, n_state, vb.pp("embed_tokens"))?;
+ let positional_embedding = vb.get((n_ctx, n_state), "embed_positions.weight")?;
+ let blocks = (0..cfg.decoder_layers)
+ .map(|i| {
+ ResidualAttentionBlock::load(n_state, n_head, true, vb.pp(&format!("layers.{i}")))
+ })
+ .collect::<Result<Vec<_>>>()?;
+ let ln = layer_norm(n_state, vb.pp("layer_norm"))?;
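+ // Causal mask: entry (i, j) is -inf for j > i, so a position can only
+ // attend to itself and to earlier positions once the mask is added to the
+ // attention scores.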
+ let mask: Vec<_> = (0..n_ctx)
+ .flat_map(|i| (0..n_ctx).map(move |j| if j > i { f32::NEG_INFINITY } else { 0f32 }))
+ .collect();
+ let mask = Tensor::from_vec(mask, (n_ctx, n_ctx), vb.device())?;
+
+ Ok(Self {
+ token_embedding,
+ positional_embedding,
+ blocks,
+ ln,
+ mask,
+ })
+ }
+
+ pub fn forward(&self, x: &Tensor, xa: &Tensor) -> Result<Tensor> {
+ let x_dims = x.dims();
+ let last = x_dims[x_dims.len() - 1];
+ let token_embedding = self.token_embedding.forward(x)?;
+ let positional_embedding = self.positional_embedding.narrow(0, 0, last)?;
+ let mut x = token_embedding.broadcast_add(&positional_embedding)?;
+ for block in self.blocks.iter() {
+ x = block.forward(&x, Some(xa), Some(&self.mask))?;
+ }
+ let x = self.ln.forward(&x)?;
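+ // Weight tying: logits are obtained by multiplying the hidden states with
+ // the transpose of the token embedding matrix, broadcast over the batch.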
+ let w = self
+ .token_embedding
+ .embeddings()
+ .broadcast_left(x_dims[0])?;
+ let logits = x.matmul(&w.t()?)?;
+ Ok(logits)
+ }
+}
+
+// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L221
+pub struct Whisper {
+ pub encoder: AudioEncoder,
+ pub decoder: TextDecoder,
+ pub config: Config,
+}
+
+impl Whisper {
+ pub fn load(vb: &VarBuilder, config: Config) -> Result<Self> {
+ let encoder = AudioEncoder::load(vb.pp("model.encoder"), &config)?;
+ let decoder = TextDecoder::load(vb.pp("model.decoder"), &config)?;
+ Ok(Self {
+ encoder,
+ decoder,
+ config,
+ })
+ }
+
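+ // A minimal usage sketch (the `vb`, `mel`, and `tokens` values are assumed
+ // to be set up elsewhere, e.g. as in the worker below):
+ //     let model = Whisper::load(&vb, Config::tiny_en())?;
+ //     let logits = model.forward(&mel, &tokens)?;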
+ pub fn forward(&self, mel: &Tensor, tokens: &Tensor) -> Result<Tensor> {
+ let enc = self.encoder.forward(mel)?;
+ let dec = self.decoder.forward(tokens, &enc)?;
+ Ok(dec)
+ }
+}
diff --git a/candle-wasm-examples/whisper/src/worker.rs b/candle-wasm-examples/whisper/src/worker.rs
new file mode 100644
index 00000000..ea64bf02
--- /dev/null
+++ b/candle-wasm-examples/whisper/src/worker.rs
@@ -0,0 +1,345 @@
+use crate::model::{Config, Whisper};
+use anyhow::Error as E;
+use candle::{safetensors::Load, DType, Device, Tensor};
+use candle_nn::VarBuilder;
+use rand::{distributions::Distribution, rngs::StdRng, SeedableRng};
+use serde::{Deserialize, Serialize};
+use tokenizers::Tokenizer;
+use wasm_bindgen::prelude::*;
+use yew_agent::{HandlerId, Public, WorkerLink};
+
+#[wasm_bindgen]
+extern "C" {
+ // Use `js_namespace` here to bind `console.log(..)` instead of just
+ // `log(..)`
+ #[wasm_bindgen(js_namespace = console)]
+ pub fn log(s: &str);
+}
+
+#[macro_export]
+macro_rules! console_log {
+ // Note that this uses the `log` binding to `console.log` imported above.
+ ($($t:tt)*) => ($crate::worker::log(&format_args!($($t)*).to_string()))
+}
+
+pub const DTYPE: DType = DType::F32;
+
+// Audio parameters.
+pub const SAMPLE_RATE: usize = 16000;
+pub const N_FFT: usize = 400;
+pub const N_MELS: usize = 80;
+pub const HOP_LENGTH: usize = 160;
+pub const CHUNK_LENGTH: usize = 30;
+pub const N_SAMPLES: usize = CHUNK_LENGTH * SAMPLE_RATE; // 480000 samples in a 30-second chunk
+pub const N_FRAMES: usize = N_SAMPLES / HOP_LENGTH; // 3000 frames in a mel spectrogram input
+pub const N_SAMPLES_PER_TOKEN: usize = HOP_LENGTH * 2; // the initial convolutions have an overall stride of 2
+pub const FRAMES_PER_SECOND: usize = SAMPLE_RATE / HOP_LENGTH; // 10ms per audio frame
+pub const TOKENS_PER_SECOND: usize = SAMPLE_RATE / N_SAMPLES_PER_TOKEN; // 20ms per audio token
+
+pub const NO_SPEECH_THRESHOLD: f64 = 0.6;
+pub const LOGPROB_THRESHOLD: f64 = -1.0;
+pub const TEMPERATURES: [f64; 6] = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0];
+pub const COMPRESSION_RATIO_THRESHOLD: f64 = 2.4;
+
+// Tokenizer dependent bits.
+pub const SOT_TOKEN: u32 = 50257;
+pub const EOT_TOKEN: u32 = 50256;
+pub const NO_SPEECH_TOKEN: u32 = 50361;
+pub const NO_TIMESTAMP_TOKEN: u32 = 50362;
+// From the _get_suppress_tokens function + 50362 (no timestamp)
+// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/decoding.py#L605
+pub const SUPPRESS_TOKENS: [u32; 91] = [
+ 1, 2, 7, 8, 9, 10, 14, 25, 26, 27, 28, 29, 31, 58, 59, 60, 61, 62, 63, 90, 91, 92, 93, 357,
+ 366, 438, 532, 685, 705, 796, 930, 1058, 1220, 1267, 1279, 1303, 1343, 1377, 1391, 1635, 1782,
+ 1875, 2162, 2361, 2488, 3467, 4008, 4211, 4600, 4808, 5299, 5855, 6329, 7203, 9609, 9959,
+ 10563, 10786, 11420, 11709, 11907, 13163, 13697, 13700, 14808, 15306, 16410, 16791, 17992,
+ 19203, 19510, 20724, 22305, 22935, 27007, 30109, 30420, 33409, 34949, 40283, 40493, 40549,
+ 47282, 49146, 50257, 50357, 50358, 50359, 50360, 50361, 50362,
+];
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct DecodingResult {
+ pub tokens: Vec<u32>,
+ pub text: String,
+ pub avg_logprob: f64,
+ pub no_speech_prob: f64,
+ temperature: f64,
+ compression_ratio: f64,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct Segment {
+ pub start: f64,
+ pub duration: f64,
+ pub dr: DecodingResult,
+}
+
+pub struct Decoder {
+ model: Whisper,
+ mel_filters: Vec<f32>,
+ tokenizer: Tokenizer,
+ suppress_tokens: Tensor,
+}
+
+impl Decoder {
+ fn new(
+ model: Whisper,
+ tokenizer: Tokenizer,
+ mel_filters: Vec<f32>,
+ device: &Device,
+ ) -> anyhow::Result<Self> {
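+ // Build an additive bias over the vocabulary: suppressed token ids get
+ // -inf so that they are never sampled once the bias is broadcast-added to
+ // the logits during decoding.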
+ let suppress_tokens: Vec<f32> = (0..model.config.vocab_size as u32)
+ .map(|i| {
+ if SUPPRESS_TOKENS.contains(&i) {
+ f32::NEG_INFINITY
+ } else {
+ 0f32
+ }
+ })
+ .collect();
+ let suppress_tokens = Tensor::new(suppress_tokens.as_slice(), device)?;
+ Ok(Self {
+ model,
+ mel_filters,
+ tokenizer,
+ suppress_tokens,
+ })
+ }
+
+ fn decode(&self, mel: &Tensor, t: f64, rng: &mut StdRng) -> anyhow::Result<DecodingResult> {
+ let model = &self.model;
+ let audio_features = model.encoder.forward(mel)?;
+ console_log!("audio features: {:?}", audio_features.dims());
+ let sample_len = model.config.max_target_positions / 2;
+ let mut sum_logprob = 0f64;
+ let mut no_speech_prob = f64::NAN;
+ let mut tokens = vec![SOT_TOKEN];
+ for i in 0..sample_len {
+ let tokens_t = Tensor::new(tokens.as_slice(), mel.device())?;
+
+ // The model expects a batch dimension, but this inference loop does not
+ // handle one, so we add it here.
+ let tokens_t = tokens_t.unsqueeze(0)?;
+ let logits = model.decoder.forward(&tokens_t, &audio_features)?;
+ let logits = logits.squeeze(0)?;
+
+ // Extract the no-speech probability on the first iteration by looking at
+ // the logits for the first token and taking the probability of the
+ // no-speech token.
+ if i == 0 {
+ no_speech_prob = logits
+ .get(0)?
+ .softmax(0)?
+ .get(NO_SPEECH_TOKEN as usize)?
+ .to_scalar::<f32>()? as f64;
+ }
+
+ let (seq_len, _) = logits.dims2()?;
+ let logits = logits
+ .get(seq_len - 1)?
+ .broadcast_add(&self.suppress_tokens)?;
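+ // With a positive temperature, sample from softmax(logits / t); at t == 0,
+ // fall back to greedy decoding and take the arg-max token.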
+ let next_token = if t > 0f64 {
+ let prs = (&logits / t)?.softmax(0)?;
+ let probs: Vec<f32> = prs.to_vec1()?;
+ let distr = rand::distributions::WeightedIndex::new(&probs)?;
+ distr.sample(rng) as u32
+ } else {
+ let logits_v: Vec<f32> = logits.to_vec1()?;
+ logits_v
+ .iter()
+ .enumerate()
+ .max_by(|(_, u), (_, v)| u.total_cmp(v))
+ .map(|(i, _)| i as u32)
+ .unwrap()
+ };
+ tokens.push(next_token);
+ let prob = logits
+ .softmax(candle::D::Minus1)?
+ .get(next_token as usize)?
+ .to_scalar::<f32>()? as f64;
+ if next_token == EOT_TOKEN || tokens.len() > model.config.max_target_positions {
+ break;
+ }
+ sum_logprob += prob.ln();
+ }
+ let text = self
+ .tokenizer
+ .decode(tokens.clone(), true)
+ .map_err(E::msg)?;
+ let avg_logprob = sum_logprob / tokens.len() as f64;
+
+ Ok(DecodingResult {
+ tokens,
+ text,
+ avg_logprob,
+ no_speech_prob,
+ temperature: t,
+ compression_ratio: f64::NAN,
+ })
+ }
+
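+ // Try the temperatures in `TEMPERATURES` in order; a result is accepted
+ // unless its compression ratio or average log-probability suggests
+ // degenerate output, in which case the next (higher) temperature is tried.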
+ fn decode_with_fallback(
+ &self,
+ segment: &Tensor,
+ rng: &mut StdRng,
+ ) -> anyhow::Result<DecodingResult> {
+ for (i, &t) in TEMPERATURES.iter().enumerate() {
+ let dr: Result<DecodingResult, _> = self.decode(segment, t, rng);
+ if i == TEMPERATURES.len() - 1 {
+ return dr;
+ }
+ // Retry with the next temperature on errors or on low-quality results.
+ match dr {
+ Ok(dr) => {
+ let needs_fallback = dr.compression_ratio > COMPRESSION_RATIO_THRESHOLD
+ || dr.avg_logprob < LOGPROB_THRESHOLD;
+ if !needs_fallback || dr.no_speech_prob > NO_SPEECH_THRESHOLD {
+ return Ok(dr);
+ }
+ }
+ Err(err) => {
+ console_log!("Error running at {t}: {err}")
+ }
+ }
+ }
+ unreachable!()
+ }
+
+ fn run(&self, mel: &Tensor) -> anyhow::Result<Vec<Segment>> {
+ let mut rng = StdRng::seed_from_u64(299792458);
+ let (_, _, content_frames) = mel.dims3()?;
+ let mut seek = 0;
+ let mut segments = vec![];
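+ // Decode the mel spectrogram in windows of at most N_FRAMES (30s of
+ // audio), advancing `seek` past each decoded segment.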
+ while seek < content_frames {
+ let time_offset = (seek * HOP_LENGTH) as f64 / SAMPLE_RATE as f64;
+ let segment_size = usize::min(content_frames - seek, N_FRAMES);
+ let mel_segment = mel.narrow(2, seek, segment_size)?;
+ let segment_duration = (segment_size * HOP_LENGTH) as f64 / SAMPLE_RATE as f64;
+ let dr = self.decode_with_fallback(&mel_segment, &mut rng)?;
+ seek += segment_size;
+ if dr.no_speech_prob > NO_SPEECH_THRESHOLD && dr.avg_logprob < LOGPROB_THRESHOLD {
+ console_log!("no speech detected, skipping {seek} {dr:?}");
+ continue;
+ }
+ let segment = Segment {
+ start: time_offset,
+ duration: segment_duration,
+ dr,
+ };
+ console_log!("{seek}: {segment:?}");
+ segments.push(segment)
+ }
+ Ok(segments)
+ }
+
+ fn load(md: ModelData) -> anyhow::Result<Self> {
+ let device = Device::Cpu;
+ let tokenizer = Tokenizer::from_bytes(&md.tokenizer).map_err(anyhow::Error::msg)?;
+
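+ // The mel filterbank is stored as a safetensors tensor named `mel_80` and
+ // flattened to a plain Vec<f32> for the audio preprocessing code.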
+ let mel_filters = candle::safetensors::SafeTensors::deserialize(&md.mel_filters)?;
+ let mel_filters = mel_filters.tensor("mel_80")?.load(&device)?;
+ console_log!("loaded mel filters {:?}", mel_filters.shape());
+ let mel_filters = mel_filters.flatten_all()?.to_vec1::<f32>()?;
+ let weights = candle::safetensors::SafeTensors::deserialize(&md.weights)?;
+ let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &device);
+ let config = Config::tiny_en();
+ let whisper = Whisper::load(&vb, config)?;
+ console_log!("done loading model");
+ let decoder = Self::new(whisper, tokenizer, mel_filters, &device)?;
+ Ok(decoder)
+ }
+
+ fn convert_and_run(&self, wav_input: &[u8]) -> anyhow::Result<Vec<Segment>> {
+ let device = Device::Cpu;
+ let mut wav_input = std::io::Cursor::new(wav_input);
+ let (header, data) = wav::read(&mut wav_input)?;
+ console_log!("loaded wav data: {header:?}");
+ if header.sampling_rate != SAMPLE_RATE as u32 {
+ anyhow::bail!("wav file must have a {SAMPLE_RATE} sampling rate");
+ }
+ let data = data.as_sixteen().expect("expected 16 bit wav file");
+ // Samples are interleaved across channels; keep the first channel only and
+ // normalize the i16 samples to f32 in [-1, 1).
+ let pcm_data: Vec<_> = data
+ .iter()
+ .step_by(header.channel_count as usize)
+ .map(|v| *v as f32 / 32768.)
+ .collect();
+ console_log!("pcm data loaded {}", pcm_data.len());
+ let mel = crate::audio::pcm_to_mel(&pcm_data, &self.mel_filters)?;
+ let mel_len = mel.len();
+ let mel = Tensor::from_vec(mel, (1, N_MELS, mel_len / N_MELS), &device)?;
+ console_log!("loaded mel: {:?}", mel.dims());
+ let segments = self.run(&mel)?;
+ Ok(segments)
+ }
+}
+
+// Communication with the worker happens through bincode; the model weights and
+// configs are fetched on the main thread and transferred via the following
+// structure.
+#[derive(Serialize, Deserialize)]
+pub struct ModelData {
+ pub tokenizer: Vec<u8>,
+ pub mel_filters: Vec<u8>,
+ pub weights: Vec<u8>,
+}
+
+pub struct Worker {
+ link: WorkerLink<Self>,
+ decoder: Option<Decoder>,
+}
+
+#[derive(Serialize, Deserialize)]
+pub enum WorkerInput {
+ ModelData(ModelData),
+ DecodeTask { wav_bytes: Vec<u8> },
+}
+
+#[derive(Serialize, Deserialize)]
+pub enum WorkerOutput {
+ Decoded(Vec<Segment>),
+ WeightsLoaded,
+}
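+
+// Protocol sketch: the app first sends `WorkerInput::ModelData` (answered with
+// `Ok(WorkerOutput::WeightsLoaded)`), then any number of `DecodeTask` requests,
+// each answered with `Ok(WorkerOutput::Decoded(segments))` or an error string.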
+
+impl yew_agent::Worker for Worker {
+ type Input = WorkerInput;
+ type Message = ();
+ type Output = Result<WorkerOutput, String>;
+ type Reach = Public<Self>;
+
+ fn create(link: WorkerLink<Self>) -> Self {
+ Self {
+ link,
+ decoder: None,
+ }
+ }
+
+ fn update(&mut self, _msg: Self::Message) {
+ // no messaging
+ }
+
+ fn handle_input(&mut self, msg: Self::Input, id: HandlerId) {
+ let output = match msg {
+ WorkerInput::ModelData(md) => match Decoder::load(md) {
+ Ok(decoder) => {
+ self.decoder = Some(decoder);
+ Ok(WorkerOutput::WeightsLoaded)
+ }
+ Err(err) => Err(format!("model creation error {err:?}")),
+ },
+ WorkerInput::DecodeTask { wav_bytes } => match &self.decoder {
+ None => Err("model has not been set".to_string()),
+ Some(decoder) => decoder
+ .convert_and_run(&wav_bytes)
+ .map(WorkerOutput::Decoded)
+ .map_err(|e| e.to_string()),
+ },
+ };
+ self.link.respond(id, output);
+ }
+
+ fn name_of_resource() -> &'static str {
+ "worker.js"
+ }
+
+ fn resource_path_is_relative() -> bool {
+ true
+ }
+}