diff options
Diffstat (limited to 'candle-wasm-examples')
-rw-r--r-- | candle-wasm-examples/llama2-c/Cargo.toml | 51 | ||||
-rw-r--r-- | candle-wasm-examples/llama2-c/index.html | 17 | ||||
-rw-r--r-- | candle-wasm-examples/llama2-c/src/app.rs | 188 | ||||
-rw-r--r-- | candle-wasm-examples/llama2-c/src/bin/app.rs | 4 | ||||
-rw-r--r-- | candle-wasm-examples/llama2-c/src/bin/worker.rs | 4 | ||||
-rw-r--r-- | candle-wasm-examples/llama2-c/src/lib.rs | 30 | ||||
-rw-r--r-- | candle-wasm-examples/llama2-c/src/model.rs | 321 | ||||
-rw-r--r-- | candle-wasm-examples/llama2-c/src/worker.rs | 353 | ||||
-rw-r--r-- | candle-wasm-examples/whisper/Cargo.toml | 52 | ||||
-rw-r--r-- | candle-wasm-examples/whisper/index.html | 24 | ||||
-rw-r--r-- | candle-wasm-examples/whisper/main.js | 6 | ||||
-rw-r--r-- | candle-wasm-examples/whisper/src/app.rs | 238 | ||||
-rw-r--r-- | candle-wasm-examples/whisper/src/audio.rs | 217 | ||||
-rw-r--r-- | candle-wasm-examples/whisper/src/bin/app.rs | 4 | ||||
-rw-r--r-- | candle-wasm-examples/whisper/src/bin/worker.rs | 4 | ||||
-rw-r--r-- | candle-wasm-examples/whisper/src/lib.rs | 31 | ||||
-rw-r--r-- | candle-wasm-examples/whisper/src/model.rs | 421 | ||||
-rw-r--r-- | candle-wasm-examples/whisper/src/worker.rs | 345 |
18 files changed, 2310 insertions, 0 deletions
diff --git a/candle-wasm-examples/llama2-c/Cargo.toml b/candle-wasm-examples/llama2-c/Cargo.toml new file mode 100644 index 00000000..22d9cfe8 --- /dev/null +++ b/candle-wasm-examples/llama2-c/Cargo.toml @@ -0,0 +1,51 @@ +[package] +name = "candle-wasm-example-llama2" +version = "0.1.0" +edition = "2021" + +description = "Wasm example for the candle ML framework." +repository = "https://github.com/LaurentMazare/candle" +keywords = ["blas", "tensor", "machine-learning"] +categories = ["science"] +license = "MIT/Apache-2.0" +readme = "README.md" + +[dependencies] +candle = { path = "../../candle-core" } +candle-nn = { path = "../../candle-nn" } +num-traits = { workspace = true } + +# App crates. +anyhow = { workspace = true } +byteorder = { workspace = true } +log = { workspace = true } +rand = { workspace = true } +serde = { workspace = true } +serde_json = { workspace = true } + +# Wasm specific crates. +getrandom = { version = "0.2", features = ["js"] } +gloo = "0.8" +js-sys = "0.3.64" +wasm-bindgen = "0.2.87" +wasm-bindgen-futures = "0.4.37" +wasm-logger = "0.2" +yew-agent = "0.2.0" +yew = { version = "0.20.0", features = ["csr"] } + +[dependencies.web-sys] +version = "0.3.64" +features = [ + 'Blob', + 'Document', + 'Element', + 'HtmlElement', + 'Node', + 'Window', + 'Request', + 'RequestCache', + 'RequestInit', + 'RequestMode', + 'Response', + 'Performance', +] diff --git a/candle-wasm-examples/llama2-c/index.html b/candle-wasm-examples/llama2-c/index.html new file mode 100644 index 00000000..e98e1ecb --- /dev/null +++ b/candle-wasm-examples/llama2-c/index.html @@ -0,0 +1,17 @@ +<!DOCTYPE html> +<html lang="en"> + <head> + <meta charset="utf-8" /> + <title>Welcome to Candle!</title> + + <link data-trunk rel="copy-file" href="tokenizer.bin" /> + <link data-trunk rel="copy-file" href="model.bin" /> + <link data-trunk rel="rust" href="Cargo.toml" data-bin="app" data-type="main" /> + <link data-trunk rel="rust" href="Cargo.toml" data-bin="worker" data-type="worker" /> + + <link rel="stylesheet" href="https://fonts.googleapis.com/css?family=Roboto:300,300italic,700,700italic"> + <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/normalize/8.0.1/normalize.css"> + <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/milligram/1.4.1/milligram.css"> + </head> + <body></body> +</html> diff --git a/candle-wasm-examples/llama2-c/src/app.rs b/candle-wasm-examples/llama2-c/src/app.rs new file mode 100644 index 00000000..460ac053 --- /dev/null +++ b/candle-wasm-examples/llama2-c/src/app.rs @@ -0,0 +1,188 @@ +use crate::console_log; +use crate::worker::{ModelData, Worker, WorkerInput, WorkerOutput}; +use wasm_bindgen::prelude::*; +use wasm_bindgen_futures::JsFuture; +use yew::{html, Component, Context, Html}; +use yew_agent::{Bridge, Bridged}; + +async fn fetch_url(url: &str) -> Result<Vec<u8>, JsValue> { + use web_sys::{Request, RequestCache, RequestInit, RequestMode, Response}; + let window = web_sys::window().ok_or("window")?; + let mut opts = RequestInit::new(); + let opts = opts + .method("GET") + .mode(RequestMode::Cors) + .cache(RequestCache::NoCache); + + let request = Request::new_with_str_and_init(url, opts)?; + + let resp_value = JsFuture::from(window.fetch_with_request(&request)).await?; + + // `resp_value` is a `Response` object. + assert!(resp_value.is_instance_of::<Response>()); + let resp: Response = resp_value.dyn_into()?; + let data = JsFuture::from(resp.blob()?).await?; + let blob = web_sys::Blob::from(data); + let array_buffer = JsFuture::from(blob.array_buffer()).await?; + let data = js_sys::Uint8Array::new(&array_buffer).to_vec(); + Ok(data) +} + +pub enum Msg { + Run, + UpdateStatus(String), + SetModel(ModelData), + WorkerInMsg(WorkerInput), + WorkerOutMsg(Result<WorkerOutput, String>), +} + +pub struct CurrentDecode { + start_time: Option<f64>, +} + +pub struct App { + status: String, + generated: String, + current_decode: Option<CurrentDecode>, + worker: Box<dyn Bridge<Worker>>, +} + +async fn model_data_load() -> Result<ModelData, JsValue> { + let tokenizer = fetch_url("tokenizer.bin").await?; + let model = fetch_url("model.bin").await?; + console_log!("{}", model.len()); + Ok(ModelData { tokenizer, model }) +} + +fn performance_now() -> Option<f64> { + let window = web_sys::window()?; + let performance = window.performance()?; + Some(performance.now() / 1000.) +} + +impl Component for App { + type Message = Msg; + type Properties = (); + + fn create(ctx: &Context<Self>) -> Self { + let status = "loading weights".to_string(); + let cb = { + let link = ctx.link().clone(); + move |e| link.send_message(Self::Message::WorkerOutMsg(e)) + }; + let worker = Worker::bridge(std::rc::Rc::new(cb)); + Self { + status, + generated: String::new(), + current_decode: None, + worker, + } + } + + fn rendered(&mut self, ctx: &Context<Self>, first_render: bool) { + if first_render { + ctx.link().send_future(async { + match model_data_load().await { + Err(err) => { + let status = format!("{err:?}"); + Msg::UpdateStatus(status) + } + Ok(model_data) => Msg::SetModel(model_data), + } + }); + } + } + + fn update(&mut self, ctx: &Context<Self>, msg: Self::Message) -> bool { + match msg { + Msg::SetModel(md) => { + self.status = "weights loaded succesfully!".to_string(); + console_log!("loaded weights"); + self.worker.send(WorkerInput::ModelData(md)); + true + } + Msg::Run => { + if self.current_decode.is_some() { + self.status = "already generating some sample at the moment".to_string() + } else { + let start_time = performance_now(); + self.current_decode = Some(CurrentDecode { start_time }); + self.status = "generating...".to_string(); + self.generated.clear(); + ctx.link().send_message(Msg::WorkerInMsg(WorkerInput::Run)) + } + true + } + Msg::WorkerOutMsg(output) => { + match output { + Ok(WorkerOutput::WeightsLoaded) => self.status = "weights loaded!".to_string(), + Ok(WorkerOutput::GenerationDone(Err(err))) => { + self.status = format!("error in worker process: {err}"); + self.current_decode = None + } + Ok(WorkerOutput::GenerationDone(Ok(()))) => { + let dt = self.current_decode.as_ref().and_then(|current_decode| { + current_decode.start_time.and_then(|start_time| { + performance_now().map(|stop_time| stop_time - start_time) + }) + }); + self.status = match dt { + None => "generation succeeded!".to_string(), + Some(dt) => format!("generation succeeded in {:.2}s", dt), + }; + self.current_decode = None + } + Ok(WorkerOutput::Generated(token)) => self.generated.push_str(&token), + Err(err) => { + self.status = format!("error in worker {err:?}"); + } + } + true + } + Msg::WorkerInMsg(inp) => { + self.worker.send(inp); + true + } + Msg::UpdateStatus(status) => { + self.status = status; + true + } + } + } + + fn view(&self, ctx: &Context<Self>) -> Html { + html! { + <div> + <div><p>{"Running "} + <a href="https://github.com/karpathy/llama2.c" target="_blank">{"llama2.c"}</a> + {" in the browser using rust/wasm with "} + <a href="https://github.com/LaurentMazare/candle" target="_blank">{"candle!"}</a> + </p> + <p>{"Once the weights have loaded, click on the run button to start generating content."} + </p> + </div> + <button class="button" onclick={ctx.link().callback(move |_| Msg::Run)}> { "run" }</button> + <br/ > + <h3> + {&self.status} + </h3> + { + if self.current_decode.is_some() { + html! { <progress id="progress-bar" aria-label="generating…"></progress> } + } else { + html! {} + } + } + <blockquote> + <p> { self.generated.chars().map(|c| + if c == '\r' || c == '\n' { + html! { <br/> } + } else { + html! { {c} } + }).collect::<Html>() + } </p> + </blockquote> + </div> + } + } +} diff --git a/candle-wasm-examples/llama2-c/src/bin/app.rs b/candle-wasm-examples/llama2-c/src/bin/app.rs new file mode 100644 index 00000000..3428f6ff --- /dev/null +++ b/candle-wasm-examples/llama2-c/src/bin/app.rs @@ -0,0 +1,4 @@ +fn main() { + wasm_logger::init(wasm_logger::Config::new(log::Level::Trace)); + yew::Renderer::<candle_wasm_example_llama2::App>::new().render(); +} diff --git a/candle-wasm-examples/llama2-c/src/bin/worker.rs b/candle-wasm-examples/llama2-c/src/bin/worker.rs new file mode 100644 index 00000000..d8ca2172 --- /dev/null +++ b/candle-wasm-examples/llama2-c/src/bin/worker.rs @@ -0,0 +1,4 @@ +use yew_agent::PublicWorker; +fn main() { + candle_wasm_example_llama2::Worker::register(); +} diff --git a/candle-wasm-examples/llama2-c/src/lib.rs b/candle-wasm-examples/llama2-c/src/lib.rs new file mode 100644 index 00000000..61154d04 --- /dev/null +++ b/candle-wasm-examples/llama2-c/src/lib.rs @@ -0,0 +1,30 @@ +#![allow(dead_code)] + +pub const WITH_TIMER: bool = true; + +struct Timer { + label: &'static str, +} + +impl Timer { + fn new(label: &'static str) -> Self { + if WITH_TIMER { + web_sys::console::time_with_label(label); + } + Self { label } + } +} + +impl Drop for Timer { + fn drop(&mut self) { + if WITH_TIMER { + web_sys::console::time_end_with_label(self.label) + } + } +} + +mod app; +mod model; +mod worker; +pub use app::App; +pub use worker::Worker; diff --git a/candle-wasm-examples/llama2-c/src/model.rs b/candle-wasm-examples/llama2-c/src/model.rs new file mode 100644 index 00000000..13f939db --- /dev/null +++ b/candle-wasm-examples/llama2-c/src/model.rs @@ -0,0 +1,321 @@ +use candle::{DType, Device, IndexOp, Result, Tensor, D}; +use candle_nn::{Embedding, Linear, VarBuilder}; +use std::collections::HashMap; +use std::sync::{Arc, Mutex}; + +#[derive(Debug, Clone)] +pub struct Config { + pub dim: usize, // transformer dimension + pub hidden_dim: usize, // for ffn layers + pub n_layers: usize, // number of layers + pub n_heads: usize, // number of query heads + pub n_kv_heads: usize, // number of key/value heads (can be < query heads because of multiquery) + pub vocab_size: usize, // vocabulary size, usually 256 (byte-level) + pub seq_len: usize, // max sequence length + pub norm_eps: f64, +} + +#[derive(Clone)] +pub struct Cache { + masks: Arc<Mutex<HashMap<usize, Tensor>>>, + pub use_kv_cache: bool, + #[allow(clippy::type_complexity)] + kvs: Arc<Mutex<Vec<Option<(Tensor, Tensor)>>>>, + cos: Tensor, + sin: Tensor, + device: Device, +} + +impl Cache { + pub fn new(use_kv_cache: bool, cfg: &Config, vb: VarBuilder) -> Result<Self> { + let freq_cis_real = vb.get((cfg.seq_len, cfg.head_size() / 2), "freq_cis_real")?; + let freq_cis_imag = vb.get((cfg.seq_len, cfg.head_size() / 2), "freq_cis_imag")?; + let cos = freq_cis_real.reshape((cfg.seq_len, cfg.head_size() / 2, 1))?; + let sin = freq_cis_imag.reshape((cfg.seq_len, cfg.head_size() / 2, 1))?; + Ok(Self { + masks: Arc::new(Mutex::new(HashMap::new())), + use_kv_cache, + kvs: Arc::new(Mutex::new(vec![None; cfg.n_layers])), + cos, + sin, + device: vb.device().clone(), + }) + } + + fn mask(&self, t: usize) -> Result<Tensor> { + let mut masks = self.masks.lock().unwrap(); + if let Some(mask) = masks.get(&t) { + Ok(mask.clone()) + } else { + // TODO: If we support bool or u8 tensors, this would be better. + let mask: Vec<_> = (0..t) + .flat_map(|i| (0..t).map(move |j| u32::from(j > i))) + .collect(); + let mask = Tensor::from_slice(&mask, (t, t), &self.device)?; + masks.insert(t, mask.clone()); + Ok(mask) + } + } +} + +fn silu(xs: &Tensor) -> Result<Tensor> { + xs / (xs.neg()?.exp()? + 1.0)? +} + +fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> { + let weight = vb.get((size2, size1), "weight")?; + Ok(Linear::new(weight, None)) +} + +fn embedding(cfg: &Config, vb: VarBuilder) -> Result<Embedding> { + let embeddings = vb.get((cfg.vocab_size, cfg.dim), "weight")?; + Ok(Embedding::new(embeddings, cfg.dim)) +} + +struct RmsNorm { + scale: Tensor, + eps: f64, +} + +impl RmsNorm { + fn load(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> { + let scale = vb.get(size, "weight")?; + Ok(Self { scale, eps }) + } + + fn forward(&self, x: &Tensor) -> Result<Tensor> { + let (b_sz, seq_len, hidden_size) = x.dims3()?; + let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?; + let norm_x = norm_x.broadcast_as((b_sz, seq_len, hidden_size))?; + let x_normed = (x / (norm_x + self.eps)?.sqrt()?)?; + let size = self.scale.dims1()?; + let scale = self + .scale + .to_dtype(DType::F32)? + .broadcast_as((b_sz, seq_len, size))?; + let x = (scale * x_normed)?; + Ok(x) + } +} + +struct CausalSelfAttention { + q_proj: Linear, + k_proj: Linear, + v_proj: Linear, + o_proj: Linear, + n_head: usize, + n_key_value_head: usize, + head_dim: usize, + cache: Cache, + max_seq_len: usize, +} + +impl CausalSelfAttention { + fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> { + let (b_sz, seq_len, h, n_embd) = x.dims4()?; + let cos = self.cache.cos.narrow(0, index_pos, seq_len)?; + let sin = self.cache.sin.narrow(0, index_pos, seq_len)?; + let cos = cos.broadcast_as((b_sz, seq_len, 1, n_embd / 2, 1))?; + let sin = sin.broadcast_as((b_sz, seq_len, 1, n_embd / 2, 1))?; + let x = x.reshape((b_sz, seq_len, h, n_embd / 2, 2))?; + let x0 = x.narrow(D::Minus1, 0, 1)?; + let x1 = x.narrow(D::Minus1, 1, 1)?; + let dst0 = (x0.broadcast_mul(&cos)? - x1.broadcast_mul(&sin)?)?; + let dst1 = (x0.broadcast_mul(&sin)? + x1.broadcast_mul(&cos)?)?; + let rope = Tensor::cat(&[&dst0, &dst1], D::Minus1)?.reshape((b_sz, seq_len, h, n_embd))?; + Ok(rope) + } + + fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> { + let (b_sz, seq_len, n_embd) = x.dims3()?; + let q = self.q_proj.forward(x)?; + let k = self.k_proj.forward(x)?; + let v = self.v_proj.forward(x)?; + + let q = q.reshape((b_sz, seq_len, self.n_head, self.head_dim))?; + let k = k.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?; + let mut v = v.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?; + + let q = self.apply_rotary_emb(&q, index_pos)?; + let mut k = self.apply_rotary_emb(&k, index_pos)?; + + if self.cache.use_kv_cache { + let mut cache = self.cache.kvs.lock().unwrap(); + if let Some((cache_k, cache_v)) = &cache[block_idx] { + k = Tensor::cat(&[cache_k, &k], 1)?.contiguous()?; + v = Tensor::cat(&[cache_v, &v], 1)?.contiguous()?; + } + cache[block_idx] = Some((k.clone(), v.clone())) + } + + let k = self.repeat_kv(k)?; + let v = self.repeat_kv(v)?; + + let q = q.transpose(1, 2)?.contiguous()?; + let k = k.transpose(1, 2)?.contiguous()?; + let v = v.transpose(1, 2)?.contiguous()?; + + let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?; + let mask = self.cache.mask(seq_len)?.broadcast_as(att.shape())?; + let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?; + let att = att.softmax(D::Minus1)?; + // Convert to contiguous as matmul doesn't support strided vs for now. + let y = att.matmul(&v.contiguous()?)?; + let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?; + let y = self.o_proj.forward(&y)?; + Ok(y) + } + + fn repeat_kv(&self, x: Tensor) -> Result<Tensor> { + let n_rep = self.n_head / self.n_key_value_head; + if n_rep == 1 { + Ok(x) + } else { + let (b_sz, seq_len, n_kv_head, head_dim) = x.dims4()?; + let x = x + .unsqueeze(3)? + .expand((b_sz, seq_len, n_kv_head, n_rep, head_dim))? + .reshape((b_sz, seq_len, n_kv_head * n_rep, head_dim))?; + Ok(x) + } + } + + fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> { + let size_in = cfg.dim; + let size_q = (cfg.dim / cfg.n_heads) * cfg.n_heads; + let size_kv = (cfg.dim / cfg.n_heads) * cfg.n_kv_heads; + let q_proj = linear(size_in, size_q, vb.pp("q_proj"))?; + let k_proj = linear(size_in, size_kv, vb.pp("k_proj"))?; + let v_proj = linear(size_in, size_kv, vb.pp("v_proj"))?; + let o_proj = linear(size_q, size_in, vb.pp("o_proj"))?; + Ok(Self { + q_proj, + k_proj, + v_proj, + o_proj, + n_head: cfg.n_heads, + n_key_value_head: cfg.n_kv_heads, + head_dim: cfg.dim / cfg.n_heads, + cache: cache.clone(), + max_seq_len: cfg.seq_len, + }) + } +} + +fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> { + let shape = mask.shape(); + let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?; + let m = mask.where_cond(&on_true, on_false)?; + Ok(m) +} + +struct Mlp { + c_fc1: Linear, + c_fc2: Linear, + c_proj: Linear, +} + +impl Mlp { + fn new(c_fc1: Linear, c_fc2: Linear, c_proj: Linear) -> Self { + Self { + c_fc1, + c_fc2, + c_proj, + } + } + + fn forward(&self, x: &Tensor) -> Result<Tensor> { + let x = (silu(&self.c_fc1.forward(x)?)? * self.c_fc2.forward(x)?)?; + self.c_proj.forward(&x) + } + + fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> { + let h_size = cfg.dim; + let i_size = cfg.hidden_dim; + let c_fc1 = linear(h_size, i_size, vb.pp("gate_proj"))?; + let c_fc2 = linear(h_size, i_size, vb.pp("up_proj"))?; + let c_proj = linear(i_size, h_size, vb.pp("down_proj"))?; + Ok(Self::new(c_fc1, c_fc2, c_proj)) + } +} + +struct Block { + rms_1: RmsNorm, + attn: CausalSelfAttention, + rms_2: RmsNorm, + mlp: Mlp, +} + +impl Block { + fn new(rms_1: RmsNorm, attn: CausalSelfAttention, rms_2: RmsNorm, mlp: Mlp) -> Self { + Self { + rms_1, + attn, + rms_2, + mlp, + } + } + + fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> { + let residual = x; + let x = self.rms_1.forward(x)?; + let x = (self.attn.forward(&x, index_pos, block_idx)? + residual)?; + let residual = &x; + let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? + residual)?; + Ok(x) + } + + fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> { + let attn = CausalSelfAttention::load(vb.pp("self_attn"), cache, cfg)?; + let mlp = Mlp::load(vb.pp("mlp"), cfg)?; + let input_layernorm = RmsNorm::load(cfg.dim, cfg.norm_eps, vb.pp("input_layernorm"))?; + let post_attention_layernorm = + RmsNorm::load(cfg.dim, cfg.norm_eps, vb.pp("post_attention_layernorm"))?; + Ok(Self::new( + input_layernorm, + attn, + post_attention_layernorm, + mlp, + )) + } +} + +pub struct Llama { + wte: Embedding, + blocks: Vec<Block>, + ln_f: RmsNorm, + lm_head: Linear, +} + +impl Llama { + fn new(wte: Embedding, blocks: Vec<Block>, ln_f: RmsNorm, lm_head: Linear) -> Self { + Self { + wte, + blocks, + ln_f, + lm_head, + } + } + + pub fn forward(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> { + let (_b_sz, seq_len) = x.dims2()?; + let mut x = self.wte.forward(x)?; + for (block_idx, block) in self.blocks.iter().enumerate() { + x = block.forward(&x, index_pos, block_idx)?; + } + let x = self.ln_f.forward(&x)?; + let x = x.i((.., seq_len - 1, ..))?; + let logits = self.lm_head.forward(&x)?; + logits.to_dtype(DType::F32) + } + + pub fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> { + let wte = embedding(cfg, vb.pp("model.embed_tokens"))?; + let lm_head = linear(cfg.dim, cfg.vocab_size, vb.pp("lm_head"))?; + let norm = RmsNorm::load(cfg.dim, cfg.norm_eps, vb.pp("model.norm"))?; + let blocks: Vec<_> = (0..cfg.n_layers) + .map(|i| Block::load(vb.pp(&format!("model.layers.{i}")), cache, cfg).unwrap()) + .collect(); + Ok(Self::new(wte, blocks, norm, lm_head)) + } +} diff --git a/candle-wasm-examples/llama2-c/src/worker.rs b/candle-wasm-examples/llama2-c/src/worker.rs new file mode 100644 index 00000000..9b0351d6 --- /dev/null +++ b/candle-wasm-examples/llama2-c/src/worker.rs @@ -0,0 +1,353 @@ +use crate::model::{Cache, Config, Llama}; +use byteorder::{LittleEndian, ReadBytesExt}; +use candle::{DType, Device, IndexOp, Result, Shape, Tensor, D}; +use candle_nn::VarBuilder; +use rand::{distributions::Distribution, SeedableRng}; +use serde::{Deserialize, Serialize}; +use wasm_bindgen::prelude::*; +use yew_agent::{HandlerId, Public, WorkerLink}; + +#[wasm_bindgen] +extern "C" { + // Use `js_namespace` here to bind `console.log(..)` instead of just + // `log(..)` + #[wasm_bindgen(js_namespace = console)] + pub fn log(s: &str); +} + +#[macro_export] +macro_rules! console_log { + // Note that this is using the `log` function imported above during + // `bare_bones` + ($($t:tt)*) => ($crate::worker::log(&format_args!($($t)*).to_string())) +} + +// Communication to the worker happens through bincode, the model weights and configs are fetched +// on the main thread and transfered via the following structure. +#[derive(Serialize, Deserialize)] +pub struct ModelData { + pub tokenizer: Vec<u8>, + pub model: Vec<u8>, +} + +fn read_i32<R: std::io::Read>(r: &mut R) -> Result<i32> { + let mut buf = [0u8; 4]; + r.read_exact(&mut buf)?; + Ok(i32::from_le_bytes(buf)) +} + +fn read_tensor<R: std::io::Read, S: Into<Shape>>( + r: &mut R, + shape: S, + dev: &Device, +) -> Result<Tensor> { + let shape = shape.into(); + let mut data_t = vec![0f32; shape.elem_count()]; + r.read_f32_into::<LittleEndian>(&mut data_t)?; + let tensor = Tensor::from_vec(data_t, shape, dev)?; + Ok(tensor) +} + +struct Tokenizer { + tokens: Vec<String>, +} + +impl Tokenizer { + fn from_reader<R: std::io::Read>(r: &mut R, c: &Config) -> Result<Self> { + let mut tokens = Vec::with_capacity(c.vocab_size); + for _token_index in 0..c.vocab_size { + let token_len = read_i32(r)?; + let mut token = vec![0u8; token_len as usize]; + r.read_exact(&mut token)?; + tokens.push(String::from_utf8_lossy(&token).into_owned()) + } + Ok(Self { tokens }) + } +} + +struct Model { + cache: Cache, + config: Config, + llama: Llama, + tokenizer: Tokenizer, +} + +pub struct LogitsProcessor { + rng: rand::rngs::StdRng, + temperature: Option<f64>, +} + +impl LogitsProcessor { + pub fn new(seed: u64, temperature: Option<f64>) -> Self { + Self { + rng: rand::rngs::StdRng::seed_from_u64(seed), + temperature, + } + } + + pub fn sample(&mut self, logits: &Tensor) -> Result<u32> { + let logits = logits.to_dtype(DType::F32)?; + let next_token = if let Some(temperature) = self.temperature { + let prs = (&logits / temperature)?.softmax(D::Minus1)?; + let prs: Vec<f32> = prs.to_vec1()?; + let distr = + rand::distributions::WeightedIndex::new(prs).map_err(candle::Error::wrap)?; + distr.sample(&mut self.rng) as u32 + } else { + let logits_v: Vec<f32> = logits.to_vec1()?; + logits_v + .iter() + .enumerate() + .max_by(|(_, u), (_, v)| u.total_cmp(v)) + .map(|(i, _)| i as u32) + .unwrap() + }; + Ok(next_token) + } +} + +impl Model { + fn run(&self, link: &WorkerLink<Worker>, id: HandlerId) -> Result<()> { + let dev = Device::Cpu; + let mut logits_processor = LogitsProcessor::new(299792458, None); + let mut index_pos = 0; + let mut tokens = vec![1u32]; + + for index in 0..self.config.seq_len - 10 { + let context_size = if self.cache.use_kv_cache && index > 0 { + 1 + } else { + tokens.len() + }; + let ctxt = &tokens[tokens.len().saturating_sub(context_size)..]; + let input = Tensor::new(ctxt, &dev)?.unsqueeze(0)?; + let logits = self.llama.forward(&input, index_pos)?; + let logits = logits.squeeze(0)?; + index_pos += ctxt.len(); + + let next_token = logits_processor.sample(&logits)?; + tokens.push(next_token); + let token = self.tokenizer.tokens[next_token as usize].clone(); + link.respond(id, Ok(WorkerOutput::Generated(token))); + } + Ok(()) + } +} + +impl Config { + fn from_reader<R: std::io::Read>(r: &mut R) -> Result<Self> { + let dim = read_i32(r)? as usize; + let hidden_dim = read_i32(r)? as usize; + let n_layers = read_i32(r)? as usize; + let n_heads = read_i32(r)? as usize; + let n_kv_heads = read_i32(r)? as usize; + let vocab_size = read_i32(r)? as usize; + let seq_len = read_i32(r)? as usize; + Ok(Self { + dim, + hidden_dim, + n_layers, + n_heads, + n_kv_heads, + vocab_size, + seq_len, + norm_eps: 1e-5, + }) + } + + pub fn head_size(&self) -> usize { + self.dim / self.n_heads + } +} + +struct TransformerWeights { + // token embedding table + token_embedding_table: Tensor, // (vocab_size, dim) + // weights for rmsnorms + rms_att_weight: Tensor, // (layer, dim) rmsnorm weights + rms_ffn_weight: Tensor, // (layer, dim) + // weights for matmuls + wq: Tensor, // (layer, dim, dim) + wk: Tensor, // (layer, dim, dim) + wv: Tensor, // (layer, dim, dim) + wo: Tensor, // (layer, dim, dim) + // weights for ffn + w1: Tensor, // (layer, hidden_dim, dim) + w2: Tensor, // (layer, dim, hidden_dim) + w3: Tensor, // (layer, hidden_dim, dim) + // final rmsnorm + rms_final_weight: Tensor, // (dim,) + // freq_cis for RoPE relatively positional embeddings + freq_cis_real: Tensor, // (seq_len, head_size/2) + freq_cis_imag: Tensor, // (seq_len, head_size/2) +} + +impl TransformerWeights { + fn from_reader<R: std::io::Read>(r: &mut R, c: &Config, dev: &Device) -> Result<Self> { + let token_embedding_table = read_tensor(r, (c.vocab_size, c.dim), dev)?; + let rms_att_weight = read_tensor(r, (c.n_layers, c.dim), dev)?; + let wq = read_tensor(r, (c.n_layers, c.dim, c.dim), dev)?; + let wk = read_tensor(r, (c.n_layers, c.dim, c.dim), dev)?; + let wv = read_tensor(r, (c.n_layers, c.dim, c.dim), dev)?; + let wo = read_tensor(r, (c.n_layers, c.dim, c.dim), dev)?; + let rms_ffn_weight = read_tensor(r, (c.n_layers, c.dim), dev)?; + let w1 = read_tensor(r, (c.n_layers, c.hidden_dim, c.dim), dev)?; + let w2 = read_tensor(r, (c.n_layers, c.dim, c.hidden_dim), dev)?; + let w3 = read_tensor(r, (c.n_layers, c.hidden_dim, c.dim), dev)?; + let rms_final_weight = read_tensor(r, c.dim, dev)?; + let head_size = c.head_size(); + let freq_cis_real = read_tensor(r, (c.seq_len, head_size / 2), dev)?; + let freq_cis_imag = read_tensor(r, (c.seq_len, head_size / 2), dev)?; + Ok(Self { + token_embedding_table, + rms_att_weight, + wq, + wk, + wv, + wo, + rms_ffn_weight, + w1, + w2, + w3, + rms_final_weight, + freq_cis_real, + freq_cis_imag, + }) + } + + fn var_builder(&self, cfg: &Config, device: &Device) -> Result<VarBuilder> { + let mut ws = std::collections::HashMap::new(); + let mut insert = |name: &str, t: Tensor| { + ws.insert(name.to_string(), t); + }; + insert("rot.freq_cis_real", self.freq_cis_real.clone()); + insert("rot.freq_cis_imag", self.freq_cis_imag.clone()); + insert( + "model.embed_tokens.weight", + self.token_embedding_table.clone(), + ); + insert("lm_head.weight", self.token_embedding_table.clone()); + insert("model.norm.weight", self.rms_final_weight.clone()); + for layer in 0..cfg.n_layers { + ws.insert( + format!("model.layers.{layer}.self_attn.q_proj.weight"), + self.wq.i(layer)?, + ); + ws.insert( + format!("model.layers.{layer}.self_attn.k_proj.weight"), + self.wk.i(layer)?, + ); + ws.insert( + format!("model.layers.{layer}.self_attn.v_proj.weight"), + self.wv.i(layer)?, + ); + ws.insert( + format!("model.layers.{layer}.self_attn.o_proj.weight"), + self.wo.i(layer)?, + ); + ws.insert( + format!("model.layers.{layer}.mlp.gate_proj.weight"), + self.w1.i(layer)?, + ); + ws.insert( + format!("model.layers.{layer}.mlp.down_proj.weight"), + self.w2.i(layer)?, + ); + ws.insert( + format!("model.layers.{layer}.mlp.up_proj.weight"), + self.w3.i(layer)?, + ); + ws.insert( + format!("model.layers.{layer}.input_layernorm.weight"), + self.rms_att_weight.i(layer)?, + ); + ws.insert( + format!("model.layers.{layer}.post_attention_layernorm.weight"), + self.rms_ffn_weight.i(layer)?, + ); + } + let vb = VarBuilder::from_tensors(ws, DType::F32, device); + Ok(vb) + } +} + +impl Model { + fn load(md: ModelData) -> Result<Self> { + let dev = Device::Cpu; + let mut model = std::io::Cursor::new(md.model); + let config = Config::from_reader(&mut model)?; + let weights = TransformerWeights::from_reader(&mut model, &config, &dev)?; + let vb = weights.var_builder(&config, &dev)?; + let cache = Cache::new(true, &config, vb.pp("rot"))?; + let llama = Llama::load(vb, &cache, &config)?; + let mut tokenizer = std::io::Cursor::new(md.tokenizer); + let tokenizer = Tokenizer::from_reader(&mut tokenizer, &config)?; + Ok(Self { + cache, + config, + llama, + tokenizer, + }) + } +} + +pub struct Worker { + link: WorkerLink<Self>, + model: Option<Model>, +} + +#[derive(Serialize, Deserialize)] +pub enum WorkerInput { + ModelData(ModelData), + Run, +} + +#[derive(Serialize, Deserialize)] +pub enum WorkerOutput { + Generated(String), + GenerationDone(std::result::Result<(), String>), + WeightsLoaded, +} + +impl yew_agent::Worker for Worker { + type Input = WorkerInput; + type Message = (); + type Output = std::result::Result<WorkerOutput, String>; + type Reach = Public<Self>; + + fn create(link: WorkerLink<Self>) -> Self { + Self { link, model: None } + } + + fn update(&mut self, _msg: Self::Message) { + // no messaging + } + + fn handle_input(&mut self, msg: Self::Input, id: HandlerId) { + let output = match msg { + WorkerInput::ModelData(md) => match Model::load(md) { + Ok(model) => { + self.model = Some(model); + Ok(WorkerOutput::WeightsLoaded) + } + Err(err) => Err(format!("model creation error {err:?}")), + }, + WorkerInput::Run => match &self.model { + None => Err("model has not been set yet".to_string()), + Some(model) => { + let result = model.run(&self.link, id).map_err(|e| e.to_string()); + Ok(WorkerOutput::GenerationDone(result)) + } + }, + }; + self.link.respond(id, output); + } + + fn name_of_resource() -> &'static str { + "worker.js" + } + + fn resource_path_is_relative() -> bool { + true + } +} diff --git a/candle-wasm-examples/whisper/Cargo.toml b/candle-wasm-examples/whisper/Cargo.toml new file mode 100644 index 00000000..55b507db --- /dev/null +++ b/candle-wasm-examples/whisper/Cargo.toml @@ -0,0 +1,52 @@ +[package] +name = "candle-wasm-example-whisper" +version = "0.1.0" +edition = "2021" + +description = "Wasm example for the candle ML framework." +repository = "https://github.com/LaurentMazare/candle" +keywords = ["blas", "tensor", "machine-learning"] +categories = ["science"] +license = "MIT/Apache-2.0" +readme = "README.md" + +[dependencies] +candle = { path = "../../candle-core" } +candle-nn = { path = "../../candle-nn" } +num-traits = { workspace = true } +tokenizers = { workspace = true, features = ["unstable_wasm"] } + +# App crates. +anyhow = { workspace = true } +log = { workspace = true } +rand = { workspace = true } +serde = { workspace = true } +serde_json = { workspace = true } +wav = { workspace = true } + +# Wasm specific crates. +getrandom = { version = "0.2", features = ["js"] } +gloo = "0.8" +js-sys = "0.3.64" +wasm-bindgen = "0.2.87" +wasm-bindgen-futures = "0.4.37" +wasm-logger = "0.2" +yew-agent = "0.2.0" +yew = { version = "0.20.0", features = ["csr"] } + +[dependencies.web-sys] +version = "0.3.64" +features = [ + 'Blob', + 'Document', + 'Element', + 'HtmlElement', + 'Node', + 'Window', + 'Request', + 'RequestCache', + 'RequestInit', + 'RequestMode', + 'Response', + 'Performance', +] diff --git a/candle-wasm-examples/whisper/index.html b/candle-wasm-examples/whisper/index.html new file mode 100644 index 00000000..7a21c4f2 --- /dev/null +++ b/candle-wasm-examples/whisper/index.html @@ -0,0 +1,24 @@ +<!DOCTYPE html> +<html lang="en"> + <head> + <meta charset="utf-8" /> + <title>Welcome to Candle!</title> + + <link data-trunk rel="copy-file" href="jfk.wav" /> + <link data-trunk rel="copy-file" href="mm0.wav" /> + <link data-trunk rel="copy-file" href="a13.wav" /> + <link data-trunk rel="copy-file" href="gb0.wav" /> + <link data-trunk rel="copy-file" href="gb1.wav" /> + <link data-trunk rel="copy-file" href="hp0.wav" /> + <link data-trunk rel="copy-file" href="tokenizer.en.json" /> + <link data-trunk rel="copy-file" href="mel_filters.safetensors" /> + <link data-trunk rel="copy-file" href="tiny.en.safetensors" /> + <link data-trunk rel="rust" href="Cargo.toml" data-bin="app" data-type="main" /> + <link data-trunk rel="rust" href="Cargo.toml" data-bin="worker" data-type="worker" /> + + <link rel="stylesheet" href="https://fonts.googleapis.com/css?family=Roboto:300,300italic,700,700italic"> + <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/normalize/8.0.1/normalize.css"> + <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/milligram/1.4.1/milligram.css"> + </head> + <body></body> +</html> diff --git a/candle-wasm-examples/whisper/main.js b/candle-wasm-examples/whisper/main.js new file mode 100644 index 00000000..c27e0d60 --- /dev/null +++ b/candle-wasm-examples/whisper/main.js @@ -0,0 +1,6 @@ +import init, { run_app } from './pkg/candle_wasm_example_whisper.js'; +async function main() { + await init('/pkg/candle_wasm_example_whisper_bg.wasm'); + run_app(); +} +main() diff --git a/candle-wasm-examples/whisper/src/app.rs b/candle-wasm-examples/whisper/src/app.rs new file mode 100644 index 00000000..23519ebd --- /dev/null +++ b/candle-wasm-examples/whisper/src/app.rs @@ -0,0 +1,238 @@ +use crate::console_log; +use crate::worker::{ModelData, Segment, Worker, WorkerInput, WorkerOutput}; +use js_sys::Date; +use wasm_bindgen::prelude::*; +use wasm_bindgen_futures::JsFuture; +use yew::{html, Component, Context, Html}; +use yew_agent::{Bridge, Bridged}; + +const SAMPLE_NAMES: [&str; 6] = [ + "jfk.wav", "a13.wav", "gb0.wav", "gb1.wav", "hp0.wav", "mm0.wav", +]; + +async fn fetch_url(url: &str) -> Result<Vec<u8>, JsValue> { + use web_sys::{Request, RequestCache, RequestInit, RequestMode, Response}; + let window = web_sys::window().ok_or("window")?; + let mut opts = RequestInit::new(); + let opts = opts + .method("GET") + .mode(RequestMode::Cors) + .cache(RequestCache::NoCache); + + let request = Request::new_with_str_and_init(url, opts)?; + + let resp_value = JsFuture::from(window.fetch_with_request(&request)).await?; + + // `resp_value` is a `Response` object. + assert!(resp_value.is_instance_of::<Response>()); + let resp: Response = resp_value.dyn_into()?; + let data = JsFuture::from(resp.blob()?).await?; + let blob = web_sys::Blob::from(data); + let array_buffer = JsFuture::from(blob.array_buffer()).await?; + let data = js_sys::Uint8Array::new(&array_buffer).to_vec(); + Ok(data) +} + +pub enum Msg { + Run(usize), + UpdateStatus(String), + SetDecoder(ModelData), + WorkerInMsg(WorkerInput), + WorkerOutMsg(Result<WorkerOutput, String>), +} + +pub struct CurrentDecode { + start_time: Option<f64>, +} + +pub struct App { + status: String, + segments: Vec<Segment>, + current_decode: Option<CurrentDecode>, + worker: Box<dyn Bridge<Worker>>, +} + +async fn model_data_load() -> Result<ModelData, JsValue> { + let tokenizer = fetch_url("tokenizer.en.json").await?; + let mel_filters = fetch_url("mel_filters.safetensors").await?; + let weights = fetch_url("tiny.en.safetensors").await?; + console_log!("{}", weights.len()); + Ok(ModelData { + tokenizer, + mel_filters, + weights, + }) +} + +fn performance_now() -> Option<f64> { + let window = web_sys::window()?; + let performance = window.performance()?; + Some(performance.now() / 1000.) +} + +impl Component for App { + type Message = Msg; + type Properties = (); + + fn create(ctx: &Context<Self>) -> Self { + let status = "loading weights".to_string(); + let cb = { + let link = ctx.link().clone(); + move |e| link.send_message(Self::Message::WorkerOutMsg(e)) + }; + let worker = Worker::bridge(std::rc::Rc::new(cb)); + Self { + status, + segments: vec![], + current_decode: None, + worker, + } + } + + fn rendered(&mut self, ctx: &Context<Self>, first_render: bool) { + if first_render { + ctx.link().send_future(async { + match model_data_load().await { + Err(err) => { + let status = format!("{err:?}"); + Msg::UpdateStatus(status) + } + Ok(model_data) => Msg::SetDecoder(model_data), + } + }); + } + } + + fn update(&mut self, ctx: &Context<Self>, msg: Self::Message) -> bool { + match msg { + Msg::SetDecoder(md) => { + self.status = "weights loaded succesfully!".to_string(); + console_log!("loaded weights"); + self.worker.send(WorkerInput::ModelData(md)); + true + } + Msg::Run(sample_index) => { + let sample = SAMPLE_NAMES[sample_index]; + if self.current_decode.is_some() { + self.status = "already decoding some sample at the moment".to_string() + } else { + let start_time = performance_now(); + self.current_decode = Some(CurrentDecode { start_time }); + self.status = format!("decoding {sample}"); + self.segments.clear(); + ctx.link().send_future(async move { + match fetch_url(sample).await { + Err(err) => { + let output = Err(format!("decoding error: {err:?}")); + // Mimic a worker output to so as to release current_decode + Msg::WorkerOutMsg(output) + } + Ok(wav_bytes) => { + Msg::WorkerInMsg(WorkerInput::DecodeTask { wav_bytes }) + } + } + }) + } + // + true + } + Msg::WorkerOutMsg(output) => { + let dt = self.current_decode.as_ref().and_then(|current_decode| { + current_decode.start_time.and_then(|start_time| { + performance_now().map(|stop_time| stop_time - start_time) + }) + }); + self.current_decode = None; + match output { + Ok(WorkerOutput::WeightsLoaded) => self.status = "weights loaded!".to_string(), + Ok(WorkerOutput::Decoded(segments)) => { + self.status = match dt { + None => "decoding succeeded!".to_string(), + Some(dt) => format!("decoding succeeded in {:.2}s", dt), + }; + self.segments = segments; + } + Err(err) => { + self.status = format!("decoding error {err:?}"); + } + } + true + } + Msg::WorkerInMsg(inp) => { + self.worker.send(inp); + true + } + Msg::UpdateStatus(status) => { + self.status = status; + true + } + } + } + + fn view(&self, ctx: &Context<Self>) -> Html { + html! { + <div> + <table> + <thead> + <tr> + <th>{"Sample"}</th> + <th></th> + <th></th> + </tr> + </thead> + <tbody> + { + SAMPLE_NAMES.iter().enumerate().map(|(i, name)| { html! { + <tr> + <th>{name}</th> + <th><audio controls=true src={format!("./{name}")}></audio></th> + <th><button class="button" onclick={ctx.link().callback(move |_| Msg::Run(i))}> { "run" }</button></th> + </tr> + } + }).collect::<Html>() + } + </tbody> + </table> + <h2> + {&self.status} + </h2> + { + if self.current_decode.is_some() { + html! { <progress id="progress-bar" aria-label="decoding…"></progress> } + } else { html!{ + <blockquote> + <p> + { + self.segments.iter().map(|segment| { html! { + <> + <i> + { + format!("{:.2}s-{:.2}s: (avg-logprob: {:.4}, no-speech-prob: {:.4})", + segment.start, + segment.start + segment.duration, + segment.dr.avg_logprob, + segment.dr.no_speech_prob, + ) + } + </i> + <br/ > + {&segment.dr.text} + <br/ > + </> + } }).collect::<Html>() + } + </p> + </blockquote> + } + } + } + + // Display the current date and time the page was rendered + <p class="footer"> + { "Rendered: " } + { String::from(Date::new_0().to_string()) } + </p> + </div> + } + } +} diff --git a/candle-wasm-examples/whisper/src/audio.rs b/candle-wasm-examples/whisper/src/audio.rs new file mode 100644 index 00000000..5b414368 --- /dev/null +++ b/candle-wasm-examples/whisper/src/audio.rs @@ -0,0 +1,217 @@ +// Audio processing code, adapted from whisper.cpp +// https://github.com/ggerganov/whisper.cpp +use super::worker; + +pub trait Float: num_traits::Float + num_traits::FloatConst + num_traits::NumAssign {} + +impl Float for f32 {} +impl Float for f64 {} + +// https://github.com/ggerganov/whisper.cpp/blob/4774d2feb01a772a15de81ffc34b34a1f294f020/whisper.cpp#L2357 +fn fft<T: Float>(inp: &[T]) -> Vec<T> { + let n = inp.len(); + let zero = T::zero(); + if n == 1 { + return vec![inp[0], zero]; + } + if n % 2 == 1 { + return dft(inp); + } + let mut out = vec![zero; n * 2]; + + let mut even = vec![]; + even.reserve(n / 2); + let mut odd = vec![]; + odd.reserve(n / 2); + + for (i, &inp) in inp.iter().enumerate() { + if i % 2 == 0 { + even.push(inp) + } else { + odd.push(inp); + } + } + + let even_fft = fft(&even); + let odd_fft = fft(&odd); + + let two_pi = T::PI() + T::PI(); + let n_t = T::from(n).unwrap(); + for k in 0..n / 2 { + let k_t = T::from(k).unwrap(); + let theta = two_pi * k_t / n_t; + let re = theta.cos(); + let im = -theta.sin(); + + let re_odd = odd_fft[2 * k]; + let im_odd = odd_fft[2 * k + 1]; + + out[2 * k] = even_fft[2 * k] + re * re_odd - im * im_odd; + out[2 * k + 1] = even_fft[2 * k + 1] + re * im_odd + im * re_odd; + + out[2 * (k + n / 2)] = even_fft[2 * k] - re * re_odd + im * im_odd; + out[2 * (k + n / 2) + 1] = even_fft[2 * k + 1] - re * im_odd - im * re_odd; + } + out +} + +// https://github.com/ggerganov/whisper.cpp/blob/4774d2feb01a772a15de81ffc34b34a1f294f020/whisper.cpp#L2337 +fn dft<T: Float>(inp: &[T]) -> Vec<T> { + let zero = T::zero(); + let n = inp.len(); + let two_pi = T::PI() + T::PI(); + + let mut out = Vec::new(); + out.reserve(2 * n); + let n_t = T::from(n).unwrap(); + for k in 0..n { + let k_t = T::from(k).unwrap(); + let mut re = zero; + let mut im = zero; + + for (j, &inp) in inp.iter().enumerate() { + let j_t = T::from(j).unwrap(); + let angle = two_pi * k_t * j_t / n_t; + re += inp * angle.cos(); + im -= inp * angle.sin(); + } + + out.push(re); + out.push(im); + } + out +} + +#[allow(clippy::too_many_arguments)] +// https://github.com/ggerganov/whisper.cpp/blob/4774d2feb01a772a15de81ffc34b34a1f294f020/whisper.cpp#L2414 +fn log_mel_spectrogram_w<T: Float>( + ith: usize, + hann: &[T], + samples: &[T], + filters: &[T], + fft_size: usize, + fft_step: usize, + speed_up: bool, + n_len: usize, + n_mel: usize, + n_threads: usize, +) -> Vec<T> { + let n_fft = if speed_up { + 1 + fft_size / 4 + } else { + 1 + fft_size / 2 + }; + + let zero = T::zero(); + let half = T::from(0.5).unwrap(); + let mut fft_in = vec![zero; fft_size]; + let mut mel = vec![zero; n_len * n_mel]; + + for i in (ith..n_len).step_by(n_threads) { + let offset = i * fft_step; + + // apply Hanning window + for j in 0..fft_size { + fft_in[j] = if offset + j < samples.len() { + hann[j] * samples[offset + j] + } else { + zero + } + } + + // FFT -> mag^2 + let mut fft_out: Vec<T> = fft(&fft_in); + + for j in 0..fft_size { + fft_out[j] = fft_out[2 * j] * fft_out[2 * j] + fft_out[2 * j + 1] * fft_out[2 * j + 1]; + } + for j in 1..fft_size / 2 { + let v = fft_out[fft_size - j]; + fft_out[j] += v; + } + + if speed_up { + // scale down in the frequency domain results in a speed up in the time domain + for j in 0..n_fft { + fft_out[j] = half * (fft_out[2 * j] + fft_out[2 * j + 1]); + } + } + + // mel spectrogram + for j in 0..n_mel { + let mut sum = zero; + for k in 0..n_fft { + sum += fft_out[k] * filters[j * n_fft + k]; + } + mel[j * n_len + i] = T::max(sum, T::from(1e-10).unwrap()).log10(); + } + } + mel +} + +fn log_mel_spectrogram_<T: Float + std::fmt::Display>( + samples: &[T], + filters: &[T], + fft_size: usize, + fft_step: usize, + n_mel: usize, + speed_up: bool, +) -> Vec<T> { + let zero = T::zero(); + let two_pi = T::PI() + T::PI(); + let half = T::from(0.5).unwrap(); + let one = T::from(1.0).unwrap(); + let four = T::from(4.0).unwrap(); + let fft_size_t = T::from(fft_size).unwrap(); + + let hann: Vec<T> = (0..fft_size) + .map(|i| half * (one - ((two_pi * T::from(i).unwrap()) / fft_size_t).cos())) + .collect(); + let n_len = samples.len() / fft_step; + + // pad audio with at least one extra chunk of zeros + let pad = 100 * worker::CHUNK_LENGTH / 2; + let n_len = if n_len % pad != 0 { + (n_len / pad + 1) * pad + } else { + n_len + }; + let n_len = n_len + pad; + let samples = { + let mut samples_padded = samples.to_vec(); + let to_add = n_len * fft_step - samples.len(); + samples_padded.extend(std::iter::repeat(zero).take(to_add)); + samples_padded + }; + + // Use a single thread for now. + let mut mel = log_mel_spectrogram_w( + 0, &hann, &samples, filters, fft_size, fft_step, speed_up, n_len, n_mel, 1, + ); + let mmax = mel + .iter() + .max_by(|&u, &v| u.partial_cmp(v).unwrap_or(std::cmp::Ordering::Greater)) + .copied() + .unwrap_or(zero) + - T::from(8).unwrap(); + for m in mel.iter_mut() { + let v = T::max(*m, mmax); + *m = v / four + one + } + mel +} + +pub fn pcm_to_mel<T: Float + std::fmt::Display>( + samples: &[T], + filters: &[T], +) -> anyhow::Result<Vec<T>> { + let mel = log_mel_spectrogram_( + samples, + filters, + worker::N_FFT, + worker::HOP_LENGTH, + worker::N_MELS, + false, + ); + Ok(mel) +} diff --git a/candle-wasm-examples/whisper/src/bin/app.rs b/candle-wasm-examples/whisper/src/bin/app.rs new file mode 100644 index 00000000..89efa7f7 --- /dev/null +++ b/candle-wasm-examples/whisper/src/bin/app.rs @@ -0,0 +1,4 @@ +fn main() { + wasm_logger::init(wasm_logger::Config::new(log::Level::Trace)); + yew::Renderer::<candle_wasm_example_whisper::App>::new().render(); +} diff --git a/candle-wasm-examples/whisper/src/bin/worker.rs b/candle-wasm-examples/whisper/src/bin/worker.rs new file mode 100644 index 00000000..b8c16b56 --- /dev/null +++ b/candle-wasm-examples/whisper/src/bin/worker.rs @@ -0,0 +1,4 @@ +use yew_agent::PublicWorker; +fn main() { + candle_wasm_example_whisper::Worker::register(); +} diff --git a/candle-wasm-examples/whisper/src/lib.rs b/candle-wasm-examples/whisper/src/lib.rs new file mode 100644 index 00000000..b47d43ca --- /dev/null +++ b/candle-wasm-examples/whisper/src/lib.rs @@ -0,0 +1,31 @@ +#![allow(dead_code)] + +pub const WITH_TIMER: bool = true; + +struct Timer { + label: &'static str, +} + +impl Timer { + fn new(label: &'static str) -> Self { + if WITH_TIMER { + web_sys::console::time_with_label(label); + } + Self { label } + } +} + +impl Drop for Timer { + fn drop(&mut self) { + if WITH_TIMER { + web_sys::console::time_end_with_label(self.label) + } + } +} + +mod app; +mod audio; +mod model; +mod worker; +pub use app::App; +pub use worker::Worker; diff --git a/candle-wasm-examples/whisper/src/model.rs b/candle-wasm-examples/whisper/src/model.rs new file mode 100644 index 00000000..97eff839 --- /dev/null +++ b/candle-wasm-examples/whisper/src/model.rs @@ -0,0 +1,421 @@ +#![allow(dead_code)] +// We use anyhow rather than candle errors as it provides better support for getting the backtrace +// back when using RUST_LIB_BACKTRACE=1. +use anyhow::Result; +use candle::{Device, Tensor}; +use candle_nn::{Conv1d, Conv1dConfig, Embedding, LayerNorm, VarBuilder}; +use serde::Deserialize; + +// The names in comments correspond to the original implementation: +// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L17 +#[derive(Debug, Clone, PartialEq, Deserialize)] +pub struct Config { + pub num_mel_bins: usize, // n_mels + pub max_source_positions: usize, // n_audio_ctx + pub d_model: usize, // n_audio_state + pub encoder_attention_heads: usize, // n_audio_head + pub encoder_layers: usize, // n_audio_layer + pub vocab_size: usize, // n_vocab + pub max_target_positions: usize, // n_text_ctx + // pub n_text_state: usize, + pub decoder_attention_heads: usize, // n_text_head + pub decoder_layers: usize, // n_text_layer +} + +impl Config { + pub fn tiny_en() -> Self { + Self { + num_mel_bins: 80, + vocab_size: 51864, + max_source_positions: 1500, + d_model: 384, + encoder_attention_heads: 6, + encoder_layers: 4, + max_target_positions: 448, + // n_text_state: 384, + decoder_attention_heads: 6, + decoder_layers: 4, + } + } +} + +// The struct below is duplicated from candle_nn::Linear so that it's easier to add some wasm +// specific monitoring. +#[derive(Debug)] +struct Linear { + weight: Tensor, + bias: Option<Tensor>, +} + +impl Linear { + fn new(weight: Tensor, bias: Option<Tensor>) -> Self { + Self { weight, bias } + } + + fn forward(&self, x: &Tensor) -> candle::Result<Tensor> { + let _timer = crate::Timer::new("Linear::forward"); + let w = match x.dims() { + &[bsize, _, _] => self.weight.broadcast_left(bsize)?.t()?, + _ => self.weight.t()?, + }; + let x = { + let _timer = crate::Timer::new("Linear::matmul"); + x.matmul(&w)? + }; + match &self.bias { + None => Ok(x), + Some(bias) => x.broadcast_add(bias), + } + } +} + +fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> { + let embeddings = vb.get((vocab_size, hidden_size), "weight")?; + Ok(Embedding::new(embeddings, hidden_size)) +} + +fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> { + let weight = vb.get((size2, size1), "weight")?; + let bias = vb.get(size2, "bias")?; + Ok(Linear::new(weight, Some(bias))) +} + +fn linear_no_bias(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> { + let weight = vb.get((size2, size1), "weight")?; + Ok(Linear::new(weight, None)) +} + +fn conv1d( + in_channels: usize, + out_channels: usize, + kernel_size: usize, + config: Conv1dConfig, + vb: VarBuilder, +) -> Result<Conv1d> { + let weight = vb.get((out_channels, in_channels, kernel_size), "weight")?; + let bias = vb.get(out_channels, "bias")?; + Ok(Conv1d::new(weight, Some(bias), config)) +} + +fn conv1d_no_bias( + in_channels: usize, + out_channels: usize, + kernel_size: usize, + config: Conv1dConfig, + vb: VarBuilder, +) -> Result<Conv1d> { + let weight = vb.get((out_channels, in_channels, kernel_size), "weight")?; + Ok(Conv1d::new(weight, None, config)) +} + +struct Dropout { + pr: f64, +} + +impl Dropout { + fn new(pr: f64) -> Self { + Self { pr } + } + + fn forward(&self, x: &Tensor) -> Result<Tensor> { + // TODO + Ok(x.clone()) + } +} + +fn layer_norm(size: usize, vb: VarBuilder) -> Result<LayerNorm> { + let weight = vb.get(size, "weight")?; + let bias = vb.get(size, "bias")?; + Ok(LayerNorm::new(weight, bias, 1e-5)) +} + +// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L62 +struct MultiHeadAttention { + query: Linear, + key: Linear, + value: Linear, + out: Linear, + n_head: usize, +} + +impl MultiHeadAttention { + fn load(n_state: usize, n_head: usize, vb: VarBuilder) -> Result<Self> { + let query = linear(n_state, n_state, vb.pp("q_proj"))?; + let value = linear(n_state, n_state, vb.pp("v_proj"))?; + let key = linear_no_bias(n_state, n_state, vb.pp("k_proj"))?; + let out = linear(n_state, n_state, vb.pp("out_proj"))?; + Ok(Self { + query, + key, + value, + out, + n_head, + }) + } + + fn forward(&self, x: &Tensor, xa: Option<&Tensor>, mask: Option<&Tensor>) -> Result<Tensor> { + let _timer = crate::Timer::new("MultiHeadAttention::forward"); + let q = self.query.forward(x)?; + let k = self.key.forward(xa.unwrap_or(x))?; + let v = self.value.forward(xa.unwrap_or(x))?; + let wv = self.qkv_attention(&q, &k, &v, mask)?; + let out = self.out.forward(&wv)?; + Ok(out) + } + + fn reshape_head(&self, x: &Tensor) -> Result<Tensor> { + let (n_batch, n_ctx, n_state) = x.dims3()?; + let target_dims = &[n_batch, n_ctx, self.n_head, n_state / self.n_head]; + Ok(x.reshape(target_dims)?.transpose(1, 2)?) + } + + fn qkv_attention( + &self, + q: &Tensor, + k: &Tensor, + v: &Tensor, + mask: Option<&Tensor>, + ) -> Result<Tensor> { + let (_, n_ctx, n_state) = q.dims3()?; + let scale = ((n_state / self.n_head) as f64).powf(-0.25); + let q = { + let _timer = crate::Timer::new("q::reshape"); + (self.reshape_head(q)? * scale)? + }; + let k = { + let _timer = crate::Timer::new("k::reshape"); + (self.reshape_head(k)?.transpose(2, 3)? * scale)? + }; + let v = { + let _timer = crate::Timer::new("v::reshape-contiguous"); + self.reshape_head(v)?.contiguous()? + }; + let mut qk = { + let _timer = crate::Timer::new("qk::matmul"); + q.matmul(&k)? + }; + if let Some(mask) = mask { + let mask = mask.narrow(0, 0, n_ctx)?.narrow(1, 0, n_ctx)?; + qk = qk.broadcast_add(&mask)? + } + let w = { + let _timer = crate::Timer::new("qk::softmax"); + qk.softmax(candle::D::Minus1)? + }; + let wv = { + let _timer = crate::Timer::new("wv::matmul"); + w.matmul(&v)?.transpose(1, 2)?.flatten_from(2)? + }; + Ok(wv) + } +} + +// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L111 +struct ResidualAttentionBlock { + attn: MultiHeadAttention, + attn_ln: LayerNorm, + cross_attn: Option<(MultiHeadAttention, LayerNorm)>, + mlp_linear1: Linear, + mlp_linear2: Linear, + mlp_ln: LayerNorm, +} + +impl ResidualAttentionBlock { + fn load(n_state: usize, n_head: usize, ca: bool, vb: VarBuilder) -> Result<Self> { + let attn = MultiHeadAttention::load(n_state, n_head, vb.pp("self_attn"))?; + let attn_ln = layer_norm(n_state, vb.pp("self_attn_layer_norm"))?; + let cross_attn = if ca { + let cross_attn = MultiHeadAttention::load(n_state, n_head, vb.pp("encoder_attn"))?; + let cross_attn_ln = layer_norm(n_state, vb.pp("encoder_attn_layer_norm"))?; + Some((cross_attn, cross_attn_ln)) + } else { + None + }; + let n_mlp = n_state * 4; + let mlp_linear1 = linear(n_state, n_mlp, vb.pp("fc1"))?; + let mlp_linear2 = linear(n_mlp, n_state, vb.pp("fc2"))?; + let mlp_ln = layer_norm(n_state, vb.pp("final_layer_norm"))?; + Ok(Self { + attn, + attn_ln, + cross_attn, + mlp_linear1, + mlp_linear2, + mlp_ln, + }) + } + + fn forward(&self, x: &Tensor, xa: Option<&Tensor>, mask: Option<&Tensor>) -> Result<Tensor> { + let _timer = crate::Timer::new("ResidualAttentionBlock::forward"); + let attn = self.attn.forward(&self.attn_ln.forward(x)?, None, mask)?; + let mut x = (x + attn)?; + if let Some((attn, ln)) = &self.cross_attn { + x = (&x + attn.forward(&ln.forward(&x)?, xa, None)?)?; + } + let mlp = self.mlp_linear2.forward( + &self + .mlp_linear1 + .forward(&self.mlp_ln.forward(&x)?)? + .gelu()?, + )?; + Ok((x + mlp)?) + } +} + +fn sinusoids(length: usize, channels: usize) -> Result<Tensor> { + let max_timescale = 10000f32; + let log_timescale_increment = max_timescale.ln() / (channels / 2 - 1) as f32; + let inv_timescales: Vec<_> = (0..channels / 2) + .map(|i| (i as f32 * (-log_timescale_increment)).exp()) + .collect(); + let inv_timescales = Tensor::new(inv_timescales.as_slice(), &Device::Cpu)?.unsqueeze(0)?; + let arange = Tensor::arange(0, length as u32, &Device::Cpu)? + .to_dtype(candle::DType::F32)? + .unsqueeze(1)?; + let sh = (length, channels / 2); + let scaled_time = (arange.broadcast_as(sh)? * inv_timescales.broadcast_as(sh)?)?; + let sincos = Tensor::cat(&[scaled_time.sin()?, scaled_time.cos()?], 1)?; + Ok(sincos) +} + +// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L143 +pub struct AudioEncoder { + conv1: Conv1d, + conv2: Conv1d, + positional_embedding: Tensor, + blocks: Vec<ResidualAttentionBlock>, + ln_post: LayerNorm, +} + +impl AudioEncoder { + fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> { + let n_state = cfg.d_model; + let n_head = cfg.encoder_attention_heads; + let n_ctx = cfg.max_source_positions; + let cfg1 = Conv1dConfig { + padding: 1, + stride: 1, + }; + let cfg2 = Conv1dConfig { + padding: 1, + stride: 2, + }; + let conv1 = conv1d(cfg.num_mel_bins, n_state, 3, cfg1, vb.pp("conv1"))?; + let conv2 = conv1d(n_state, n_state, 3, cfg2, vb.pp("conv2"))?; + let positional_embedding = sinusoids(n_ctx, n_state)?.to_device(vb.device())?; + let blocks = (0..cfg.encoder_layers) + .map(|i| { + ResidualAttentionBlock::load(n_state, n_head, false, vb.pp(&format!("layers.{i}"))) + }) + .collect::<Result<Vec<_>>>()?; + let ln_post = layer_norm(n_state, vb.pp("layer_norm"))?; + Ok(Self { + conv1, + conv2, + positional_embedding, + blocks, + ln_post, + }) + } + pub fn forward(&self, x: &Tensor) -> Result<Tensor> { + let _timer = crate::Timer::new("AudioEncoder::forward"); + let x = { + let _timer = crate::Timer::new("conv1::forward"); + self.conv1.forward(x)?.gelu()? + }; + let x = { + let _timer = crate::Timer::new("conv2::forward"); + self.conv2.forward(&x)?.gelu()? + }; + let x = x.transpose(1, 2)?; + let (_bsize, seq_len, _hidden) = x.dims3()?; + let positional_embedding = self.positional_embedding.narrow(0, 0, seq_len)?; + let mut x = x.broadcast_add(&positional_embedding)?; + for block in self.blocks.iter() { + x = block.forward(&x, None, None)? + } + let x = self.ln_post.forward(&x)?; + Ok(x) + } +} + +// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L176 +pub struct TextDecoder { + token_embedding: Embedding, + positional_embedding: Tensor, + blocks: Vec<ResidualAttentionBlock>, + ln: LayerNorm, + mask: Tensor, +} + +impl TextDecoder { + fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> { + let _timer = crate::Timer::new("TextDecoder::forward"); + let n_state = cfg.d_model; + let n_head = cfg.decoder_attention_heads; + let n_ctx = cfg.max_target_positions; + let token_embedding = embedding(cfg.vocab_size, n_state, vb.pp("embed_tokens"))?; + let positional_embedding = vb.get((n_ctx, n_state), "embed_positions.weight")?; + let blocks = (0..cfg.decoder_layers) + .map(|i| { + ResidualAttentionBlock::load(n_state, n_head, true, vb.pp(&format!("layers.{i}"))) + }) + .collect::<Result<Vec<_>>>()?; + let ln = layer_norm(n_state, vb.pp("layer_norm"))?; + let mask: Vec<_> = (0..n_ctx) + .flat_map(|i| (0..n_ctx).map(move |j| if j > i { f32::NEG_INFINITY } else { 0f32 })) + .collect(); + let mask = Tensor::from_vec(mask, (n_ctx, n_ctx), vb.device())?; + + Ok(Self { + token_embedding, + positional_embedding, + blocks, + ln, + mask, + }) + } + + pub fn forward(&self, x: &Tensor, xa: &Tensor) -> Result<Tensor> { + let x_dims = x.dims(); + let last = x_dims[x_dims.len() - 1]; + let token_embedding = self.token_embedding.forward(x)?; + let positional_embedding = self.positional_embedding.narrow(0, 0, last)?; + let mut x = token_embedding.broadcast_add(&positional_embedding)?; + for block in self.blocks.iter() { + x = block.forward(&x, Some(xa), Some(&self.mask))?; + } + let x = self.ln.forward(&x)?; + let w = self + .token_embedding + .embeddings() + .broadcast_left(x_dims[0])?; + let logits = x.matmul(&w.t()?)?; + Ok(logits) + } +} + +// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L221 +pub struct Whisper { + pub encoder: AudioEncoder, + pub decoder: TextDecoder, + pub config: Config, +} + +impl Whisper { + pub fn load(vb: &VarBuilder, config: Config) -> Result<Self> { + let encoder = AudioEncoder::load(vb.pp("model.encoder"), &config)?; + let decoder = TextDecoder::load(vb.pp("model.decoder"), &config)?; + Ok(Self { + encoder, + decoder, + config, + }) + } + + pub fn forward(&self, mel: &Tensor, tokens: &Tensor) -> Result<Tensor> { + let enc = self.encoder.forward(mel)?; + let dec = self.decoder.forward(tokens, &enc)?; + Ok(dec) + } +} diff --git a/candle-wasm-examples/whisper/src/worker.rs b/candle-wasm-examples/whisper/src/worker.rs new file mode 100644 index 00000000..ea64bf02 --- /dev/null +++ b/candle-wasm-examples/whisper/src/worker.rs @@ -0,0 +1,345 @@ +use crate::model::{Config, Whisper}; +use anyhow::Error as E; +use candle::{safetensors::Load, DType, Device, Tensor}; +use candle_nn::VarBuilder; +use rand::{distributions::Distribution, rngs::StdRng, SeedableRng}; +use serde::{Deserialize, Serialize}; +use tokenizers::Tokenizer; +use wasm_bindgen::prelude::*; +use yew_agent::{HandlerId, Public, WorkerLink}; + +#[wasm_bindgen] +extern "C" { + // Use `js_namespace` here to bind `console.log(..)` instead of just + // `log(..)` + #[wasm_bindgen(js_namespace = console)] + pub fn log(s: &str); +} + +#[macro_export] +macro_rules! console_log { + // Note that this is using the `log` function imported above during + // `bare_bones` + ($($t:tt)*) => ($crate::worker::log(&format_args!($($t)*).to_string())) +} + +pub const DTYPE: DType = DType::F32; + +// Audio parameters. +pub const SAMPLE_RATE: usize = 16000; +pub const N_FFT: usize = 400; +pub const N_MELS: usize = 80; +pub const HOP_LENGTH: usize = 160; +pub const CHUNK_LENGTH: usize = 30; +pub const N_SAMPLES: usize = CHUNK_LENGTH * SAMPLE_RATE; // 480000 samples in a 30-second chunk +pub const N_FRAMES: usize = N_SAMPLES / HOP_LENGTH; // 3000 frames in a mel spectrogram input +pub const N_SAMPLES_PER_TOKEN: usize = HOP_LENGTH * 2; // the initial convolutions has stride 2 +pub const FRAMES_PER_SECOND: usize = SAMPLE_RATE / HOP_LENGTH; // 10ms per audio frame +pub const TOKENS_PER_SECOND: usize = SAMPLE_RATE / N_SAMPLES_PER_TOKEN; // 20ms per audio token + +pub const NO_SPEECH_THRESHOLD: f64 = 0.6; +pub const LOGPROB_THRESHOLD: f64 = -1.0; +pub const TEMPERATURES: [f64; 6] = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]; +pub const COMPRESSION_RATIO_THRESHOLD: f64 = 2.4; + +// Tokenizer dependent bits. +pub const SOT_TOKEN: u32 = 50257; +pub const EOT_TOKEN: u32 = 50256; +pub const NO_SPEECH_TOKEN: u32 = 50361; +pub const NO_TIMESTAMP_TOKEN: u32 = 50362; +// From the _get_suppress_tokens function + 50362 (no timestamp) +// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/decoding.py#L605 +pub const SUPPRESS_TOKENS: [u32; 91] = [ + 1, 2, 7, 8, 9, 10, 14, 25, 26, 27, 28, 29, 31, 58, 59, 60, 61, 62, 63, 90, 91, 92, 93, 357, + 366, 438, 532, 685, 705, 796, 930, 1058, 1220, 1267, 1279, 1303, 1343, 1377, 1391, 1635, 1782, + 1875, 2162, 2361, 2488, 3467, 4008, 4211, 4600, 4808, 5299, 5855, 6329, 7203, 9609, 9959, + 10563, 10786, 11420, 11709, 11907, 13163, 13697, 13700, 14808, 15306, 16410, 16791, 17992, + 19203, 19510, 20724, 22305, 22935, 27007, 30109, 30420, 33409, 34949, 40283, 40493, 40549, + 47282, 49146, 50257, 50357, 50358, 50359, 50360, 50361, 50362, +]; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DecodingResult { + pub tokens: Vec<u32>, + pub text: String, + pub avg_logprob: f64, + pub no_speech_prob: f64, + temperature: f64, + compression_ratio: f64, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Segment { + pub start: f64, + pub duration: f64, + pub dr: DecodingResult, +} + +pub struct Decoder { + model: Whisper, + mel_filters: Vec<f32>, + tokenizer: Tokenizer, + suppress_tokens: Tensor, +} + +impl Decoder { + fn new( + model: Whisper, + tokenizer: Tokenizer, + mel_filters: Vec<f32>, + device: &Device, + ) -> anyhow::Result<Self> { + let suppress_tokens: Vec<f32> = (0..model.config.vocab_size as u32) + .map(|i| { + if SUPPRESS_TOKENS.contains(&i) { + f32::NEG_INFINITY + } else { + 0f32 + } + }) + .collect(); + let suppress_tokens = Tensor::new(suppress_tokens.as_slice(), device)?; + Ok(Self { + model, + mel_filters, + tokenizer, + suppress_tokens, + }) + } + + fn decode(&self, mel: &Tensor, t: f64, rng: &mut StdRng) -> anyhow::Result<DecodingResult> { + let model = &self.model; + let audio_features = model.encoder.forward(mel)?; + console_log!("audio features: {:?}", audio_features.dims()); + let sample_len = model.config.max_target_positions / 2; + let mut sum_logprob = 0f64; + let mut no_speech_prob = f64::NAN; + let mut tokens = vec![SOT_TOKEN]; + for i in 0..sample_len { + let tokens_t = Tensor::new(tokens.as_slice(), mel.device())?; + + // The model expects a batch dim but this inference loop does not handle + // it so we add it at this point. + let tokens_t = tokens_t.unsqueeze(0)?; + let logits = model.decoder.forward(&tokens_t, &audio_features)?; + let logits = logits.squeeze(0)?; + + // Extract the no speech probability on the first iteration by looking at the first + // token logits and the probability for the according token. + if i == 0 { + no_speech_prob = logits + .get(0)? + .softmax(0)? + .get(NO_SPEECH_TOKEN as usize)? + .to_scalar::<f32>()? as f64; + } + + let (seq_len, _) = logits.dims2()?; + let logits = logits + .get(seq_len - 1)? + .broadcast_add(&self.suppress_tokens)?; + let next_token = if t > 0f64 { + let prs = (&logits / t)?.softmax(0)?; + let logits_v: Vec<f32> = prs.to_vec1()?; + let distr = rand::distributions::WeightedIndex::new(&logits_v)?; + distr.sample(rng) as u32 + } else { + let logits_v: Vec<f32> = logits.to_vec1()?; + logits_v + .iter() + .enumerate() + .max_by(|(_, u), (_, v)| u.total_cmp(v)) + .map(|(i, _)| i as u32) + .unwrap() + }; + tokens.push(next_token); + let prob = logits + .softmax(candle::D::Minus1)? + .get(next_token as usize)? + .to_scalar::<f32>()? as f64; + if next_token == EOT_TOKEN || tokens.len() > model.config.max_target_positions { + break; + } + sum_logprob += prob.ln(); + } + let text = self + .tokenizer + .decode(tokens.clone(), true) + .map_err(E::msg)?; + let avg_logprob = sum_logprob / tokens.len() as f64; + + Ok(DecodingResult { + tokens, + text, + avg_logprob, + no_speech_prob, + temperature: t, + compression_ratio: f64::NAN, + }) + } + + fn decode_with_fallback( + &self, + segment: &Tensor, + rng: &mut StdRng, + ) -> anyhow::Result<DecodingResult> { + for (i, &t) in TEMPERATURES.iter().enumerate() { + let dr: Result<DecodingResult, _> = self.decode(segment, t, rng); + if i == TEMPERATURES.len() - 1 { + return dr; + } + // On errors, we try again with a different temperature. + match dr { + Ok(dr) => { + let needs_fallback = dr.compression_ratio > COMPRESSION_RATIO_THRESHOLD + || dr.avg_logprob < LOGPROB_THRESHOLD; + if !needs_fallback || dr.no_speech_prob > NO_SPEECH_THRESHOLD { + return Ok(dr); + } + } + Err(err) => { + console_log!("Error running at {t}: {err}") + } + } + } + unreachable!() + } + + fn run(&self, mel: &Tensor) -> anyhow::Result<Vec<Segment>> { + let mut rng = StdRng::seed_from_u64(299792458); + let (_, _, content_frames) = mel.dims3()?; + let mut seek = 0; + let mut segments = vec![]; + while seek < content_frames { + let time_offset = (seek * HOP_LENGTH) as f64 / SAMPLE_RATE as f64; + let segment_size = usize::min(content_frames - seek, N_FRAMES); + let mel_segment = mel.narrow(2, seek, segment_size)?; + let segment_duration = (segment_size * HOP_LENGTH) as f64 / SAMPLE_RATE as f64; + let dr = self.decode_with_fallback(&mel_segment, &mut rng)?; + seek += segment_size; + if dr.no_speech_prob > NO_SPEECH_THRESHOLD && dr.avg_logprob < LOGPROB_THRESHOLD { + console_log!("no speech detected, skipping {seek} {dr:?}"); + continue; + } + let segment = Segment { + start: time_offset, + duration: segment_duration, + dr, + }; + console_log!("{seek}: {segment:?}"); + segments.push(segment) + } + Ok(segments) + } + + fn load(md: ModelData) -> anyhow::Result<Self> { + let device = Device::Cpu; + let tokenizer = Tokenizer::from_bytes(&md.tokenizer).map_err(anyhow::Error::msg)?; + + let mel_filters = candle::safetensors::SafeTensors::deserialize(&md.mel_filters)?; + let mel_filters = mel_filters.tensor("mel_80")?.load(&device)?; + console_log!("loaded mel filters {:?}", mel_filters.shape()); + let mel_filters = mel_filters.flatten_all()?.to_vec1::<f32>()?; + let weights = candle::safetensors::SafeTensors::deserialize(&md.weights)?; + let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &device); + let config = Config::tiny_en(); + let whisper = Whisper::load(&vb, config)?; + console_log!("done loading model"); + let decoder = Self::new(whisper, tokenizer, mel_filters, &device)?; + Ok(decoder) + } + + fn convert_and_run(&self, wav_input: &[u8]) -> anyhow::Result<Vec<Segment>> { + let device = Device::Cpu; + let mut wav_input = std::io::Cursor::new(wav_input); + let (header, data) = wav::read(&mut wav_input)?; + console_log!("loaded wav data: {header:?}"); + if header.sampling_rate != SAMPLE_RATE as u32 { + anyhow::bail!("wav file must have a {SAMPLE_RATE} sampling rate"); + } + let data = data.as_sixteen().expect("expected 16 bit wav file"); + let pcm_data: Vec<_> = data[..data.len() / header.channel_count as usize] + .iter() + .map(|v| *v as f32 / 32768.) + .collect(); + console_log!("pcm data loaded {}", pcm_data.len()); + let mel = crate::audio::pcm_to_mel(&pcm_data, &self.mel_filters)?; + let mel_len = mel.len(); + let mel = Tensor::from_vec(mel, (1, N_MELS, mel_len / N_MELS), &device)?; + console_log!("loaded mel: {:?}", mel.dims()); + let segments = self.run(&mel)?; + Ok(segments) + } +} + +// Communication to the worker happens through bincode, the model weights and configs are fetched +// on the main thread and transfered via the following structure. +#[derive(Serialize, Deserialize)] +pub struct ModelData { + pub tokenizer: Vec<u8>, + pub mel_filters: Vec<u8>, + pub weights: Vec<u8>, +} + +pub struct Worker { + link: WorkerLink<Self>, + decoder: Option<Decoder>, +} + +#[derive(Serialize, Deserialize)] +pub enum WorkerInput { + ModelData(ModelData), + DecodeTask { wav_bytes: Vec<u8> }, +} + +#[derive(Serialize, Deserialize)] +pub enum WorkerOutput { + Decoded(Vec<Segment>), + WeightsLoaded, +} + +impl yew_agent::Worker for Worker { + type Input = WorkerInput; + type Message = (); + type Output = Result<WorkerOutput, String>; + type Reach = Public<Self>; + + fn create(link: WorkerLink<Self>) -> Self { + Self { + link, + decoder: None, + } + } + + fn update(&mut self, _msg: Self::Message) { + // no messaging + } + + fn handle_input(&mut self, msg: Self::Input, id: HandlerId) { + let output = match msg { + WorkerInput::ModelData(md) => match Decoder::load(md) { + Ok(decoder) => { + self.decoder = Some(decoder); + Ok(WorkerOutput::WeightsLoaded) + } + Err(err) => Err(format!("model creation error {err:?}")), + }, + WorkerInput::DecodeTask { wav_bytes } => match &self.decoder { + None => Err("model has not been set".to_string()), + Some(decoder) => decoder + .convert_and_run(&wav_bytes) + .map(WorkerOutput::Decoded) + .map_err(|e| e.to_string()), + }, + }; + self.link.respond(id, output); + } + + fn name_of_resource() -> &'static str { + "worker.js" + } + + fn resource_path_is_relative() -> bool { + true + } +} |