summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-07-14 22:06:40 +0100
committerGitHub <noreply@github.com>2023-07-14 22:06:40 +0100
commit2ddda706bde9936cbc8f90142ed4acc43390904e (patch)
tree42ab7a2321b5e128ca277f5369dbd7ebfe02e736
parentd1f5d44c04d084ac96227096bd8a1f7791201b64 (diff)
downloadcandle-2ddda706bde9936cbc8f90142ed4acc43390904e.tar.gz
candle-2ddda706bde9936cbc8f90142ed4acc43390904e.tar.bz2
candle-2ddda706bde9936cbc8f90142ed4acc43390904e.zip
Switch to using trunk. (#171)
-rw-r--r--.gitignore1
-rw-r--r--candle-wasm-example/Cargo.toml4
-rw-r--r--candle-wasm-example/index.html16
-rw-r--r--candle-wasm-example/src/app.rs369
-rw-r--r--candle-wasm-example/src/audio.rs10
-rw-r--r--candle-wasm-example/src/bin/app.rs4
-rw-r--r--candle-wasm-example/src/bin/worker.rs4
-rw-r--r--candle-wasm-example/src/lib.rs12
-rw-r--r--candle-wasm-example/src/worker.rs339
9 files changed, 425 insertions, 334 deletions
diff --git a/.gitignore b/.gitignore
index fa561541..e8d63fbb 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,6 +1,7 @@
# Generated by Cargo
# will have compiled files and executables
debug/
+dist/
target/
# Remove Cargo.lock from gitignore if creating an executable, leave it for libraries
diff --git a/candle-wasm-example/Cargo.toml b/candle-wasm-example/Cargo.toml
index a76ce940..0d7f1701 100644
--- a/candle-wasm-example/Cargo.toml
+++ b/candle-wasm-example/Cargo.toml
@@ -10,9 +10,6 @@ categories = ["science"]
license = "MIT/Apache-2.0"
readme = "README.md"
-[lib]
-crate-type = ["cdylib"]
-
[dependencies]
candle = { path = "../candle-core" }
candle-nn = { path = "../candle-nn" }
@@ -34,6 +31,7 @@ js-sys = "0.3.64"
wasm-bindgen = "0.2.87"
wasm-bindgen-futures = "0.4.37"
wasm-logger = "0.2"
+yew-agent = "0.2.0"
yew = { version = "0.20.0", features = ["csr"] }
[dependencies.web-sys]
diff --git a/candle-wasm-example/index.html b/candle-wasm-example/index.html
index a878197c..7a21c4f2 100644
--- a/candle-wasm-example/index.html
+++ b/candle-wasm-example/index.html
@@ -4,13 +4,21 @@
<meta charset="utf-8" />
<title>Welcome to Candle!</title>
- <link data-trunk rel="rust" />
- <link data-trunk rel="css" href="https://cdnjs.cloudflare.com/ajax/libs/milligram/1.4.1/milligram.min.css" />
+ <link data-trunk rel="copy-file" href="jfk.wav" />
+ <link data-trunk rel="copy-file" href="mm0.wav" />
+ <link data-trunk rel="copy-file" href="a13.wav" />
+ <link data-trunk rel="copy-file" href="gb0.wav" />
+ <link data-trunk rel="copy-file" href="gb1.wav" />
+ <link data-trunk rel="copy-file" href="hp0.wav" />
+ <link data-trunk rel="copy-file" href="tokenizer.en.json" />
+ <link data-trunk rel="copy-file" href="mel_filters.safetensors" />
+ <link data-trunk rel="copy-file" href="tiny.en.safetensors" />
+ <link data-trunk rel="rust" href="Cargo.toml" data-bin="app" data-type="main" />
+ <link data-trunk rel="rust" href="Cargo.toml" data-bin="worker" data-type="worker" />
+
<link rel="stylesheet" href="https://fonts.googleapis.com/css?family=Roboto:300,300italic,700,700italic">
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/normalize/8.0.1/normalize.css">
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/milligram/1.4.1/milligram.css">
-
- <script src="./pkg/bundle.js" defer></script>
</head>
<body></body>
</html>
diff --git a/candle-wasm-example/src/app.rs b/candle-wasm-example/src/app.rs
index 617c838d..5a88ba2e 100644
--- a/candle-wasm-example/src/app.rs
+++ b/candle-wasm-example/src/app.rs
@@ -1,289 +1,15 @@
-use crate::model::{Config, Whisper};
-use anyhow::Error as E;
-use candle::{DType, Device, Tensor};
-use candle_nn::VarBuilder;
+use crate::console_log;
+use crate::worker::{ModelData, Worker, WorkerInput, WorkerOutput};
use js_sys::Date;
-use rand::distributions::Distribution;
-use tokenizers::Tokenizer;
use wasm_bindgen::prelude::*;
use wasm_bindgen_futures::JsFuture;
use yew::{html, Component, Context, Html};
+use yew_agent::{Bridge, Bridged};
const SAMPLE_NAMES: [&str; 6] = [
"jfk.wav", "a13.wav", "gb0.wav", "gb1.wav", "hp0.wav", "mm0.wav",
];
-pub const DTYPE: DType = DType::F32;
-
-// Audio parameters.
-pub const SAMPLE_RATE: usize = 16000;
-pub const N_FFT: usize = 400;
-pub const N_MELS: usize = 80;
-pub const HOP_LENGTH: usize = 160;
-pub const CHUNK_LENGTH: usize = 30;
-pub const N_SAMPLES: usize = CHUNK_LENGTH * SAMPLE_RATE; // 480000 samples in a 30-second chunk
-pub const N_FRAMES: usize = N_SAMPLES / HOP_LENGTH; // 3000 frames in a mel spectrogram input
-pub const N_SAMPLES_PER_TOKEN: usize = HOP_LENGTH * 2; // the initial convolutions has stride 2
-pub const FRAMES_PER_SECOND: usize = SAMPLE_RATE / HOP_LENGTH; // 10ms per audio frame
-pub const TOKENS_PER_SECOND: usize = SAMPLE_RATE / N_SAMPLES_PER_TOKEN; // 20ms per audio token
-
-pub const NO_SPEECH_THRESHOLD: f64 = 0.6;
-pub const LOGPROB_THRESHOLD: f64 = -1.0;
-pub const TEMPERATURES: [f64; 6] = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0];
-pub const COMPRESSION_RATIO_THRESHOLD: f64 = 2.4;
-
-// Tokenizer dependent bits.
-pub const SOT_TOKEN: u32 = 50257;
-pub const EOT_TOKEN: u32 = 50256;
-pub const NO_SPEECH_TOKEN: u32 = 50361;
-pub const NO_TIMESTAMP_TOKEN: u32 = 50362;
-// From the _get_suppress_tokens function + 50362 (no timestamp)
-// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/decoding.py#L605
-pub const SUPPRESS_TOKENS: [u32; 91] = [
- 1, 2, 7, 8, 9, 10, 14, 25, 26, 27, 28, 29, 31, 58, 59, 60, 61, 62, 63, 90, 91, 92, 93, 357,
- 366, 438, 532, 685, 705, 796, 930, 1058, 1220, 1267, 1279, 1303, 1343, 1377, 1391, 1635, 1782,
- 1875, 2162, 2361, 2488, 3467, 4008, 4211, 4600, 4808, 5299, 5855, 6329, 7203, 9609, 9959,
- 10563, 10786, 11420, 11709, 11907, 13163, 13697, 13700, 14808, 15306, 16410, 16791, 17992,
- 19203, 19510, 20724, 22305, 22935, 27007, 30109, 30420, 33409, 34949, 40283, 40493, 40549,
- 47282, 49146, 50257, 50357, 50358, 50359, 50360, 50361, 50362,
-];
-
-#[wasm_bindgen]
-extern "C" {
- // Use `js_namespace` here to bind `console.log(..)` instead of just
- // `log(..)`
- #[wasm_bindgen(js_namespace = console)]
- fn log(s: &str);
-}
-
-macro_rules! console_log {
- // Note that this is using the `log` function imported above during
- // `bare_bones`
- ($($t:tt)*) => (log(&format_args!($($t)*).to_string()))
-}
-
-#[derive(Debug, Clone)]
-struct DecodingResult {
- tokens: Vec<u32>,
- text: String,
- avg_logprob: f64,
- no_speech_prob: f64,
- temperature: f64,
- compression_ratio: f64,
-}
-
-#[derive(Debug, Clone)]
-struct Segment {
- start: f64,
- duration: f64,
- dr: DecodingResult,
-}
-
-pub struct Decoder {
- model: Whisper,
- mel_filters: Vec<f32>,
- tokenizer: Tokenizer,
- suppress_tokens: Tensor,
-}
-
-impl Decoder {
- fn new(
- model: Whisper,
- tokenizer: Tokenizer,
- mel_filters: Vec<f32>,
- device: &Device,
- ) -> anyhow::Result<Self> {
- let suppress_tokens: Vec<f32> = (0..model.config.vocab_size as u32)
- .map(|i| {
- if SUPPRESS_TOKENS.contains(&i) {
- f32::NEG_INFINITY
- } else {
- 0f32
- }
- })
- .collect();
- let suppress_tokens = Tensor::new(suppress_tokens.as_slice(), device)?;
- Ok(Self {
- model,
- mel_filters,
- tokenizer,
- suppress_tokens,
- })
- }
-
- fn decode(&self, mel: &Tensor, t: f64) -> anyhow::Result<DecodingResult> {
- let model = &self.model;
- let audio_features = model.encoder.forward(mel)?;
- console_log!("audio features: {:?}", audio_features.dims());
- let sample_len = model.config.max_target_positions / 2;
- let mut sum_logprob = 0f64;
- let mut no_speech_prob = f64::NAN;
- let mut tokens = vec![SOT_TOKEN];
- for i in 0..sample_len {
- let tokens_t = Tensor::new(tokens.as_slice(), mel.device())?;
-
- // The model expects a batch dim but this inference loop does not handle
- // it so we add it at this point.
- let tokens_t = tokens_t.unsqueeze(0)?;
- let logits = model.decoder.forward(&tokens_t, &audio_features)?;
- let logits = logits.squeeze(0)?;
-
- // Extract the no speech probability on the first iteration by looking at the first
- // token logits and the probability for the according token.
- if i == 0 {
- no_speech_prob = logits
- .get(0)?
- .softmax(0)?
- .get(NO_SPEECH_TOKEN as usize)?
- .to_scalar::<f32>()? as f64;
- }
-
- let (seq_len, _) = logits.shape().r2()?;
- let logits = logits
- .get(seq_len - 1)?
- .broadcast_add(&self.suppress_tokens)?;
- let next_token = if t > 0f64 {
- let prs = (&logits / t)?.softmax(0)?;
- let logits_v: Vec<f32> = prs.to_vec1()?;
- let distr = rand::distributions::WeightedIndex::new(&logits_v)?;
- let mut rng = rand::thread_rng();
- distr.sample(&mut rng) as u32
- } else {
- let logits_v: Vec<f32> = logits.to_vec1()?;
- logits_v
- .iter()
- .enumerate()
- .max_by(|(_, u), (_, v)| u.total_cmp(v))
- .map(|(i, _)| i as u32)
- .unwrap()
- };
- tokens.push(next_token);
- let prob = logits
- .softmax(candle::D::Minus1)?
- .get(next_token as usize)?
- .to_scalar::<f32>()? as f64;
- if next_token == EOT_TOKEN || tokens.len() > model.config.max_target_positions {
- break;
- }
- sum_logprob += prob.ln();
- }
- let text = self
- .tokenizer
- .decode(tokens.clone(), true)
- .map_err(E::msg)?;
- let avg_logprob = sum_logprob / tokens.len() as f64;
-
- Ok(DecodingResult {
- tokens,
- text,
- avg_logprob,
- no_speech_prob,
- temperature: t,
- compression_ratio: f64::NAN,
- })
- }
-
- fn decode_with_fallback(&self, segment: &Tensor) -> anyhow::Result<DecodingResult> {
- for (i, &t) in TEMPERATURES.iter().enumerate() {
- let dr: Result<DecodingResult, _> = self.decode(segment, t);
- if i == TEMPERATURES.len() - 1 {
- return dr;
- }
- // On errors, we try again with a different temperature.
- match dr {
- Ok(dr) => {
- let needs_fallback = dr.compression_ratio > COMPRESSION_RATIO_THRESHOLD
- || dr.avg_logprob < LOGPROB_THRESHOLD;
- if !needs_fallback || dr.no_speech_prob > NO_SPEECH_THRESHOLD {
- return Ok(dr);
- }
- }
- Err(err) => {
- console_log!("Error running at {t}: {err}")
- }
- }
- }
- unreachable!()
- }
-
- fn run(&self, mel: &Tensor) -> anyhow::Result<Vec<Segment>> {
- let (_, _, content_frames) = mel.shape().r3()?;
- let mut seek = 0;
- let mut segments = vec![];
- while seek < content_frames {
- let time_offset = (seek * HOP_LENGTH) as f64 / SAMPLE_RATE as f64;
- let segment_size = usize::min(content_frames - seek, N_FRAMES);
- let mel_segment = mel.narrow(2, seek, segment_size)?;
- let segment_duration = (segment_size * HOP_LENGTH) as f64 / SAMPLE_RATE as f64;
- let dr = self.decode_with_fallback(&mel_segment)?;
- seek += segment_size;
- if dr.no_speech_prob > NO_SPEECH_THRESHOLD && dr.avg_logprob < LOGPROB_THRESHOLD {
- console_log!("no speech detected, skipping {seek} {dr:?}");
- continue;
- }
- let segment = Segment {
- start: time_offset,
- duration: segment_duration,
- dr,
- };
- console_log!("{seek}: {segment:?}");
- segments.push(segment)
- }
- Ok(segments)
- }
-
- async fn load() -> Result<Self, JsValue> {
- let device = Device::Cpu;
- let tokenizer_config = fetch_url("tokenizer.en.json").await?;
- let tokenizer = Tokenizer::from_bytes(tokenizer_config).map_err(w)?;
-
- let mel_filters = fetch_url("mel_filters.safetensors").await?;
- let mel_filters = candle::safetensors::SafeTensors::from_buffer(&mel_filters).map_err(w)?;
- let mel_filters = mel_filters.tensor("mel_80", &device).map_err(w)?;
- console_log!("loaded mel filters {:?}", mel_filters.shape());
- let mel_filters = mel_filters
- .flatten_all()
- .map_err(w)?
- .to_vec1::<f32>()
- .map_err(w)?;
- let weights = fetch_url("tiny.en.safetensors").await?;
- let weights = candle::safetensors::SafeTensors::from_buffer(&weights).map_err(w)?;
- let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &device);
- let config = Config::tiny_en();
- let whisper = Whisper::load(&vb, config).map_err(w)?;
- console_log!("done loading model");
- let model = Decoder::new(whisper, tokenizer, mel_filters, &device).map_err(w)?;
- Ok(model)
- }
-
- async fn load_and_run(&self, name: &str) -> Result<Vec<Segment>, JsValue> {
- let device = Device::Cpu;
- let wav_input = fetch_url(name).await?;
- let mut wav_input = std::io::Cursor::new(wav_input);
- let (header, data) = wav::read(&mut wav_input).map_err(w)?;
- console_log!("loaded wav data: {header:?}");
- if header.sampling_rate != SAMPLE_RATE as u32 {
- Err(format!(
- "wav file must have a {} sampling rate",
- SAMPLE_RATE
- ))?
- }
- let data = data.as_sixteen().expect("expected 16 bit wav file");
- let pcm_data: Vec<_> = data[..data.len() / header.channel_count as usize]
- .iter()
- .map(|v| *v as f32 / 32768.)
- .collect();
- console_log!("pcm data loaded {}", pcm_data.len());
- let mel = crate::audio::pcm_to_mel(&pcm_data, &self.mel_filters).map_err(w)?;
- let mel_len = mel.len();
- let mel = Tensor::from_vec(mel, (1, N_MELS, mel_len / N_MELS), &device).map_err(w)?;
- console_log!("loaded mel: {:?}", mel.dims());
-
- let segments = self.run(&mel).map_err(w)?;
- Ok(segments)
- }
-}
-
async fn fetch_url(url: &str) -> Result<Vec<u8>, JsValue> {
use web_sys::{Request, RequestCache, RequestInit, RequestMode, Response};
let window = web_sys::window().ok_or("window")?;
@@ -307,47 +33,61 @@ async fn fetch_url(url: &str) -> Result<Vec<u8>, JsValue> {
Ok(data)
}
-fn w<T: ToString>(x: T) -> String {
- x.to_string()
-}
-
pub enum Msg {
Run(usize),
UpdateStatus(String),
- RunFinished(String),
- SetDecoder(Decoder),
+ SetDecoder(ModelData),
+ WorkerInMsg(WorkerInput),
+ WorkerOutMsg(WorkerOutput),
}
pub struct App {
status: String,
content: String,
decode_in_flight: bool,
- decoder: Option<std::sync::Arc<Decoder>>,
+ worker: Box<dyn Bridge<Worker>>,
+}
+
+async fn model_data_load() -> Result<ModelData, JsValue> {
+ let tokenizer = fetch_url("tokenizer.en.json").await?;
+ let mel_filters = fetch_url("mel_filters.safetensors").await?;
+ let weights = fetch_url("tiny.en.safetensors").await?;
+ console_log!("{}", weights.len());
+ Ok(ModelData {
+ tokenizer,
+ mel_filters,
+ weights,
+ })
}
impl Component for App {
type Message = Msg;
type Properties = ();
- fn create(_ctx: &Context<Self>) -> Self {
+ fn create(ctx: &Context<Self>) -> Self {
let status = "loading weights".to_string();
+ let cb = {
+ let link = ctx.link().clone();
+ move |e| link.send_message(Self::Message::WorkerOutMsg(e))
+ };
+ let worker = Worker::bridge(std::rc::Rc::new(cb));
Self {
status,
content: String::new(),
decode_in_flight: false,
- decoder: None,
+ worker,
}
}
fn rendered(&mut self, ctx: &Context<Self>, first_render: bool) {
if first_render {
ctx.link().send_future(async {
- match Decoder::load().await {
+ match model_data_load().await {
Err(err) => {
let status = format!("{err:?}");
Msg::UpdateStatus(status)
}
- Ok(decoder) => Msg::SetDecoder(decoder),
+ Ok(model_data) => Msg::SetDecoder(model_data),
}
});
}
@@ -355,43 +95,46 @@ impl Component for App {
fn update(&mut self, ctx: &Context<Self>, msg: Self::Message) -> bool {
match msg {
- Msg::SetDecoder(decoder) => {
+ Msg::SetDecoder(md) => {
self.status = "weights loaded succesfully!".to_string();
- self.decoder = Some(std::sync::Arc::new(decoder));
+ console_log!("loaded weights");
+ self.worker.send(WorkerInput::ModelData(md));
true
}
Msg::Run(sample_index) => {
let sample = SAMPLE_NAMES[sample_index];
- match &self.decoder {
- None => self.content = "waiting for weights to load".to_string(),
- Some(decoder) => {
- if self.decode_in_flight {
- self.content = "already decoding some sample at the moment".to_string()
- } else {
- let decoder = decoder.clone();
- self.decode_in_flight = true;
- self.status = format!("decoding {sample}");
- self.content = String::new();
- ctx.link().send_future(async move {
- let content = decoder.load_and_run(sample).await;
- let content = match content {
- Err(err) => format!("decoding error: {err:?}"),
- Ok(segments) => format!("decoded succesfully: {segments:?}"),
- };
- Msg::RunFinished(content)
- })
+ if self.decode_in_flight {
+ self.content = "already decoding some sample at the moment".to_string()
+ } else {
+ self.decode_in_flight = true;
+ self.status = format!("decoding {sample}");
+ self.content = String::new();
+ ctx.link().send_future(async move {
+ match fetch_url(sample).await {
+ Err(err) => {
+ let value = Err(format!("decoding error: {err:?}"));
+ // Mimic a worker output to so as to release decode_in_flight
+ Msg::WorkerOutMsg(WorkerOutput { value })
+ }
+ Ok(wav_bytes) => {
+ Msg::WorkerInMsg(WorkerInput::DecodeTask { wav_bytes })
+ }
}
- //
- }
+ })
}
+ //
true
}
- Msg::RunFinished(content) => {
- self.status = "Run finished!".to_string();
- self.content = content;
+ Msg::WorkerOutMsg(WorkerOutput { value }) => {
+ self.status = "Worker responded!".to_string();
+ self.content = format!("{value:?}");
self.decode_in_flight = false;
true
}
+ Msg::WorkerInMsg(inp) => {
+ self.worker.send(inp);
+ true
+ }
Msg::UpdateStatus(status) => {
self.status = status;
true
diff --git a/candle-wasm-example/src/audio.rs b/candle-wasm-example/src/audio.rs
index d73c3142..5b414368 100644
--- a/candle-wasm-example/src/audio.rs
+++ b/candle-wasm-example/src/audio.rs
@@ -1,6 +1,6 @@
// Audio processing code, adapted from whisper.cpp
// https://github.com/ggerganov/whisper.cpp
-use super::app;
+use super::worker;
pub trait Float: num_traits::Float + num_traits::FloatConst + num_traits::NumAssign {}
@@ -170,7 +170,7 @@ fn log_mel_spectrogram_<T: Float + std::fmt::Display>(
let n_len = samples.len() / fft_step;
// pad audio with at least one extra chunk of zeros
- let pad = 100 * app::CHUNK_LENGTH / 2;
+ let pad = 100 * worker::CHUNK_LENGTH / 2;
let n_len = if n_len % pad != 0 {
(n_len / pad + 1) * pad
} else {
@@ -208,9 +208,9 @@ pub fn pcm_to_mel<T: Float + std::fmt::Display>(
let mel = log_mel_spectrogram_(
samples,
filters,
- app::N_FFT,
- app::HOP_LENGTH,
- app::N_MELS,
+ worker::N_FFT,
+ worker::HOP_LENGTH,
+ worker::N_MELS,
false,
);
Ok(mel)
diff --git a/candle-wasm-example/src/bin/app.rs b/candle-wasm-example/src/bin/app.rs
new file mode 100644
index 00000000..47cd450b
--- /dev/null
+++ b/candle-wasm-example/src/bin/app.rs
@@ -0,0 +1,4 @@
+fn main() {
+ wasm_logger::init(wasm_logger::Config::new(log::Level::Trace));
+ yew::Renderer::<candle_wasm_example::App>::new().render();
+}
diff --git a/candle-wasm-example/src/bin/worker.rs b/candle-wasm-example/src/bin/worker.rs
new file mode 100644
index 00000000..9d74bfda
--- /dev/null
+++ b/candle-wasm-example/src/bin/worker.rs
@@ -0,0 +1,4 @@
+use yew_agent::PublicWorker;
+fn main() {
+ candle_wasm_example::Worker::register();
+}
diff --git a/candle-wasm-example/src/lib.rs b/candle-wasm-example/src/lib.rs
index c4c0a3cf..54c2367c 100644
--- a/candle-wasm-example/src/lib.rs
+++ b/candle-wasm-example/src/lib.rs
@@ -1,14 +1,8 @@
#![allow(dead_code)]
-use wasm_bindgen::prelude::*;
mod app;
mod audio;
mod model;
-
-#[wasm_bindgen]
-pub fn run_app() -> Result<(), JsValue> {
- wasm_logger::init(wasm_logger::Config::new(log::Level::Trace));
- yew::Renderer::<app::App>::new().render();
-
- Ok(())
-}
+mod worker;
+pub use app::App;
+pub use worker::Worker;
diff --git a/candle-wasm-example/src/worker.rs b/candle-wasm-example/src/worker.rs
new file mode 100644
index 00000000..c1074ecd
--- /dev/null
+++ b/candle-wasm-example/src/worker.rs
@@ -0,0 +1,339 @@
+use crate::model::{Config, Whisper};
+use anyhow::Error as E;
+use candle::{DType, Device, Tensor};
+use candle_nn::VarBuilder;
+use rand::distributions::Distribution;
+use serde::{Deserialize, Serialize};
+use tokenizers::Tokenizer;
+use wasm_bindgen::prelude::*;
+use yew_agent::{HandlerId, Public, WorkerLink};
+
+#[wasm_bindgen]
+extern "C" {
+ // Use `js_namespace` here to bind `console.log(..)` instead of just
+ // `log(..)`
+ #[wasm_bindgen(js_namespace = console)]
+ pub fn log(s: &str);
+}
+
+#[macro_export]
+macro_rules! console_log {
+ // Note that this is using the `log` function imported above during
+ // `bare_bones`
+ ($($t:tt)*) => ($crate::worker::log(&format_args!($($t)*).to_string()))
+}
+
+pub const DTYPE: DType = DType::F32;
+
+// Audio parameters.
+pub const SAMPLE_RATE: usize = 16000;
+pub const N_FFT: usize = 400;
+pub const N_MELS: usize = 80;
+pub const HOP_LENGTH: usize = 160;
+pub const CHUNK_LENGTH: usize = 30;
+pub const N_SAMPLES: usize = CHUNK_LENGTH * SAMPLE_RATE; // 480000 samples in a 30-second chunk
+pub const N_FRAMES: usize = N_SAMPLES / HOP_LENGTH; // 3000 frames in a mel spectrogram input
+pub const N_SAMPLES_PER_TOKEN: usize = HOP_LENGTH * 2; // the initial convolutions has stride 2
+pub const FRAMES_PER_SECOND: usize = SAMPLE_RATE / HOP_LENGTH; // 10ms per audio frame
+pub const TOKENS_PER_SECOND: usize = SAMPLE_RATE / N_SAMPLES_PER_TOKEN; // 20ms per audio token
+
+pub const NO_SPEECH_THRESHOLD: f64 = 0.6;
+pub const LOGPROB_THRESHOLD: f64 = -1.0;
+pub const TEMPERATURES: [f64; 6] = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0];
+pub const COMPRESSION_RATIO_THRESHOLD: f64 = 2.4;
+
+// Tokenizer dependent bits.
+pub const SOT_TOKEN: u32 = 50257;
+pub const EOT_TOKEN: u32 = 50256;
+pub const NO_SPEECH_TOKEN: u32 = 50361;
+pub const NO_TIMESTAMP_TOKEN: u32 = 50362;
+// From the _get_suppress_tokens function + 50362 (no timestamp)
+// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/decoding.py#L605
+pub const SUPPRESS_TOKENS: [u32; 91] = [
+ 1, 2, 7, 8, 9, 10, 14, 25, 26, 27, 28, 29, 31, 58, 59, 60, 61, 62, 63, 90, 91, 92, 93, 357,
+ 366, 438, 532, 685, 705, 796, 930, 1058, 1220, 1267, 1279, 1303, 1343, 1377, 1391, 1635, 1782,
+ 1875, 2162, 2361, 2488, 3467, 4008, 4211, 4600, 4808, 5299, 5855, 6329, 7203, 9609, 9959,
+ 10563, 10786, 11420, 11709, 11907, 13163, 13697, 13700, 14808, 15306, 16410, 16791, 17992,
+ 19203, 19510, 20724, 22305, 22935, 27007, 30109, 30420, 33409, 34949, 40283, 40493, 40549,
+ 47282, 49146, 50257, 50357, 50358, 50359, 50360, 50361, 50362,
+];
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+struct DecodingResult {
+ tokens: Vec<u32>,
+ text: String,
+ avg_logprob: f64,
+ no_speech_prob: f64,
+ temperature: f64,
+ compression_ratio: f64,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct Segment {
+ start: f64,
+ duration: f64,
+ dr: DecodingResult,
+}
+
+pub struct Decoder {
+ model: Whisper,
+ mel_filters: Vec<f32>,
+ tokenizer: Tokenizer,
+ suppress_tokens: Tensor,
+}
+
+impl Decoder {
+ fn new(
+ model: Whisper,
+ tokenizer: Tokenizer,
+ mel_filters: Vec<f32>,
+ device: &Device,
+ ) -> anyhow::Result<Self> {
+ let suppress_tokens: Vec<f32> = (0..model.config.vocab_size as u32)
+ .map(|i| {
+ if SUPPRESS_TOKENS.contains(&i) {
+ f32::NEG_INFINITY
+ } else {
+ 0f32
+ }
+ })
+ .collect();
+ let suppress_tokens = Tensor::new(suppress_tokens.as_slice(), device)?;
+ Ok(Self {
+ model,
+ mel_filters,
+ tokenizer,
+ suppress_tokens,
+ })
+ }
+
+ fn decode(&self, mel: &Tensor, t: f64) -> anyhow::Result<DecodingResult> {
+ let model = &self.model;
+ let audio_features = model.encoder.forward(mel)?;
+ console_log!("audio features: {:?}", audio_features.dims());
+ let sample_len = model.config.max_target_positions / 2;
+ let mut sum_logprob = 0f64;
+ let mut no_speech_prob = f64::NAN;
+ let mut tokens = vec![SOT_TOKEN];
+ for i in 0..sample_len {
+ let tokens_t = Tensor::new(tokens.as_slice(), mel.device())?;
+
+ // The model expects a batch dim but this inference loop does not handle
+ // it so we add it at this point.
+ let tokens_t = tokens_t.unsqueeze(0)?;
+ let logits = model.decoder.forward(&tokens_t, &audio_features)?;
+ let logits = logits.squeeze(0)?;
+
+ // Extract the no speech probability on the first iteration by looking at the first
+ // token logits and the probability for the according token.
+ if i == 0 {
+ no_speech_prob = logits
+ .get(0)?
+ .softmax(0)?
+ .get(NO_SPEECH_TOKEN as usize)?
+ .to_scalar::<f32>()? as f64;
+ }
+
+ let (seq_len, _) = logits.shape().r2()?;
+ let logits = logits
+ .get(seq_len - 1)?
+ .broadcast_add(&self.suppress_tokens)?;
+ let next_token = if t > 0f64 {
+ let prs = (&logits / t)?.softmax(0)?;
+ let logits_v: Vec<f32> = prs.to_vec1()?;
+ let distr = rand::distributions::WeightedIndex::new(&logits_v)?;
+ let mut rng = rand::thread_rng();
+ distr.sample(&mut rng) as u32
+ } else {
+ let logits_v: Vec<f32> = logits.to_vec1()?;
+ logits_v
+ .iter()
+ .enumerate()
+ .max_by(|(_, u), (_, v)| u.total_cmp(v))
+ .map(|(i, _)| i as u32)
+ .unwrap()
+ };
+ tokens.push(next_token);
+ let prob = logits
+ .softmax(candle::D::Minus1)?
+ .get(next_token as usize)?
+ .to_scalar::<f32>()? as f64;
+ if next_token == EOT_TOKEN || tokens.len() > model.config.max_target_positions {
+ break;
+ }
+ sum_logprob += prob.ln();
+ }
+ let text = self
+ .tokenizer
+ .decode(tokens.clone(), true)
+ .map_err(E::msg)?;
+ let avg_logprob = sum_logprob / tokens.len() as f64;
+
+ Ok(DecodingResult {
+ tokens,
+ text,
+ avg_logprob,
+ no_speech_prob,
+ temperature: t,
+ compression_ratio: f64::NAN,
+ })
+ }
+
+ fn decode_with_fallback(&self, segment: &Tensor) -> anyhow::Result<DecodingResult> {
+ for (i, &t) in TEMPERATURES.iter().enumerate() {
+ let dr: Result<DecodingResult, _> = self.decode(segment, t);
+ if i == TEMPERATURES.len() - 1 {
+ return dr;
+ }
+ // On errors, we try again with a different temperature.
+ match dr {
+ Ok(dr) => {
+ let needs_fallback = dr.compression_ratio > COMPRESSION_RATIO_THRESHOLD
+ || dr.avg_logprob < LOGPROB_THRESHOLD;
+ if !needs_fallback || dr.no_speech_prob > NO_SPEECH_THRESHOLD {
+ return Ok(dr);
+ }
+ }
+ Err(err) => {
+ console_log!("Error running at {t}: {err}")
+ }
+ }
+ }
+ unreachable!()
+ }
+
+ fn run(&self, mel: &Tensor) -> anyhow::Result<Vec<Segment>> {
+ let (_, _, content_frames) = mel.shape().r3()?;
+ let mut seek = 0;
+ let mut segments = vec![];
+ while seek < content_frames {
+ let time_offset = (seek * HOP_LENGTH) as f64 / SAMPLE_RATE as f64;
+ let segment_size = usize::min(content_frames - seek, N_FRAMES);
+ let mel_segment = mel.narrow(2, seek, segment_size)?;
+ let segment_duration = (segment_size * HOP_LENGTH) as f64 / SAMPLE_RATE as f64;
+ let dr = self.decode_with_fallback(&mel_segment)?;
+ seek += segment_size;
+ if dr.no_speech_prob > NO_SPEECH_THRESHOLD && dr.avg_logprob < LOGPROB_THRESHOLD {
+ console_log!("no speech detected, skipping {seek} {dr:?}");
+ continue;
+ }
+ let segment = Segment {
+ start: time_offset,
+ duration: segment_duration,
+ dr,
+ };
+ console_log!("{seek}: {segment:?}");
+ segments.push(segment)
+ }
+ Ok(segments)
+ }
+
+ fn load(md: ModelData) -> anyhow::Result<Self> {
+ let device = Device::Cpu;
+ let tokenizer = Tokenizer::from_bytes(&md.tokenizer).map_err(anyhow::Error::msg)?;
+
+ let mel_filters = candle::safetensors::SafeTensors::from_buffer(&md.mel_filters)?;
+ let mel_filters = mel_filters.tensor("mel_80", &device)?;
+ console_log!("loaded mel filters {:?}", mel_filters.shape());
+ let mel_filters = mel_filters.flatten_all()?.to_vec1::<f32>()?;
+ let weights = candle::safetensors::SafeTensors::from_buffer(&md.weights)?;
+ let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &device);
+ let config = Config::tiny_en();
+ let whisper = Whisper::load(&vb, config)?;
+ console_log!("done loading model");
+ let decoder = Self::new(whisper, tokenizer, mel_filters, &device)?;
+ Ok(decoder)
+ }
+
+ fn convert_and_run(&self, wav_input: &[u8]) -> anyhow::Result<Vec<Segment>> {
+ let device = Device::Cpu;
+ let mut wav_input = std::io::Cursor::new(wav_input);
+ let (header, data) = wav::read(&mut wav_input)?;
+ console_log!("loaded wav data: {header:?}");
+ if header.sampling_rate != SAMPLE_RATE as u32 {
+ anyhow::bail!("wav file must have a {SAMPLE_RATE} sampling rate");
+ }
+ let data = data.as_sixteen().expect("expected 16 bit wav file");
+ let pcm_data: Vec<_> = data[..data.len() / header.channel_count as usize]
+ .iter()
+ .map(|v| *v as f32 / 32768.)
+ .collect();
+ console_log!("pcm data loaded {}", pcm_data.len());
+ let mel = crate::audio::pcm_to_mel(&pcm_data, &self.mel_filters)?;
+ let mel_len = mel.len();
+ let mel = Tensor::from_vec(mel, (1, N_MELS, mel_len / N_MELS), &device)?;
+ console_log!("loaded mel: {:?}", mel.dims());
+ let segments = self.run(&mel)?;
+ Ok(segments)
+ }
+}
+
+// Communication to the worker happens through bincode, the model weights and configs are fetched
+// on the main thread and transfered via the following structure.
+#[derive(Serialize, Deserialize)]
+pub struct ModelData {
+ pub tokenizer: Vec<u8>,
+ pub mel_filters: Vec<u8>,
+ pub weights: Vec<u8>,
+}
+
+pub struct Worker {
+ link: WorkerLink<Self>,
+ decoder: Option<Decoder>,
+}
+
+#[derive(Serialize, Deserialize)]
+pub enum WorkerInput {
+ ModelData(ModelData),
+ DecodeTask { wav_bytes: Vec<u8> },
+}
+
+#[derive(Serialize, Deserialize)]
+pub struct WorkerOutput {
+ pub value: Result<Vec<Segment>, String>,
+}
+
+impl yew_agent::Worker for Worker {
+ type Input = WorkerInput;
+ type Message = ();
+ type Output = WorkerOutput;
+ type Reach = Public<Self>;
+
+ fn create(link: WorkerLink<Self>) -> Self {
+ Self {
+ link,
+ decoder: None,
+ }
+ }
+
+ fn update(&mut self, _msg: Self::Message) {
+ // no messaging
+ }
+
+ fn handle_input(&mut self, msg: Self::Input, id: HandlerId) {
+ let value = match msg {
+ WorkerInput::ModelData(md) => match Decoder::load(md) {
+ Ok(decoder) => {
+ self.decoder = Some(decoder);
+ Ok(vec![])
+ }
+ Err(err) => Err(format!("model creation error {err:?}")),
+ },
+ WorkerInput::DecodeTask { wav_bytes } => match &self.decoder {
+ None => Err("model has not been set".to_string()),
+ Some(decoder) => decoder
+ .convert_and_run(&wav_bytes)
+ .map_err(|e| e.to_string()),
+ },
+ };
+ self.link.respond(id, WorkerOutput { value });
+ }
+
+ fn name_of_resource() -> &'static str {
+ "worker.js"
+ }
+
+ fn resource_path_is_relative() -> bool {
+ true
+ }
+}