summaryrefslogtreecommitdiff
path: root/candle-wasm-examples/whisper
diff options
context:
space:
mode:
authorRadamés Ajna <radamajna@gmail.com>2023-08-30 11:35:41 -0700
committerGitHub <noreply@github.com>2023-08-30 20:35:41 +0200
commit1d0bb48fae08f9bb5b6547ccff086c24b87a6775 (patch)
treedc6e66a60abe9eed85370ad30cf0474dcfa6d44c /candle-wasm-examples/whisper
parent21e1c738928eb6ad0266d63ae10f9d8d849bb124 (diff)
downloadcandle-1d0bb48fae08f9bb5b6547ccff086c24b87a6775.tar.gz
candle-1d0bb48fae08f9bb5b6547ccff086c24b87a6775.tar.bz2
candle-1d0bb48fae08f9bb5b6547ccff086c24b87a6775.zip
Improve Whisper WASM UI example (#669)
* wip add module and js worker example * params * clean up, send error * final UI with whisper webworker * add simple instructions
Diffstat (limited to 'candle-wasm-examples/whisper')
-rw-r--r--candle-wasm-examples/whisper/README.md56
-rw-r--r--candle-wasm-examples/whisper/build-lib.sh2
-rw-r--r--candle-wasm-examples/whisper/lib-example.html313
-rw-r--r--candle-wasm-examples/whisper/src/bin/m.rs41
-rw-r--r--candle-wasm-examples/whisper/src/lib.rs2
-rw-r--r--candle-wasm-examples/whisper/src/worker.rs4
-rw-r--r--candle-wasm-examples/whisper/whisperWorker.js72
7 files changed, 487 insertions, 3 deletions
diff --git a/candle-wasm-examples/whisper/README.md b/candle-wasm-examples/whisper/README.md
new file mode 100644
index 00000000..b847a965
--- /dev/null
+++ b/candle-wasm-examples/whisper/README.md
@@ -0,0 +1,56 @@
+## Running Whisper Examples
+
+Here, we provide two examples of how to run Whisper using a Candle-compiled WASM binary and runtimes.
+
+### Pure Rust UI
+
+To build and test the UI made in Rust you will need [Trunk](https://trunkrs.dev/#install).
+From the `candle-wasm-examples/whisper` directory run:
+
+Download assets:
+
+```bash
+# Model and tokenizer
+wget -c https://huggingface.co/spaces/lmz/candle-whisper/resolve/main/mel_filters.safetensors
+wget -c https://huggingface.co/spaces/lmz/candle-whisper/resolve/main/tiny.en.safetensors
+wget -c https://huggingface.co/spaces/lmz/candle-whisper/resolve/main/tokenizer.en.json
+
+
+# Audio samples
+wget -c https://huggingface.co/datasets/Narsil/candle-examples/resolve/main/samples_gb0.wav -O gb0.wav
+wget -c https://huggingface.co/datasets/Narsil/candle-examples/resolve/main/samples_a13.wav -O a13.wav
+wget -c https://huggingface.co/datasets/Narsil/candle-examples/resolve/main/samples_gb1.wav -O gb1.wav
+wget -c https://huggingface.co/datasets/Narsil/candle-examples/resolve/main/samples_hp0.wav -O hp0.wav
+wget -c https://huggingface.co/datasets/Narsil/candle-examples/resolve/main/samples_jfk.wav -O jfk.wav
+wget -c https://huggingface.co/datasets/Narsil/candle-examples/resolve/main/samples_mm0.wav -O mm0.wav
+
+```
+
+Run the hot-reload server:
+
+```bash
+trunk serve --release --public-url / --port 8080
+```
+
+### Vanilla JS and WebWorkers
+
+To build and test the UI made in Vanilla JS and WebWorkers, first we need to build the WASM library:
+
+```bash
+sh build-lib.sh
+```
+
+This will bundle the library under `./build` and we can import it inside our WebWorker like a normal JS module:
+
+```js
+import init, { Decoder } from "./build/m.js";
+```
+
+The full example can be found under `./lib-example.html`. All needed assets are fetched from the web, so there is no need to download anything.
+Finally, you can preview it by running a local HTTP server, for example:
+
+```bash
+python -m http.server
+```
+
+Then open `http://localhost:8000/lib-example.html` in your browser.
diff --git a/candle-wasm-examples/whisper/build-lib.sh b/candle-wasm-examples/whisper/build-lib.sh
new file mode 100644
index 00000000..b0ebb182
--- /dev/null
+++ b/candle-wasm-examples/whisper/build-lib.sh
@@ -0,0 +1,2 @@
+cargo build --target wasm32-unknown-unknown --release
+wasm-bindgen ../../target/wasm32-unknown-unknown/release/m.wasm --out-dir build --target web
diff --git a/candle-wasm-examples/whisper/lib-example.html b/candle-wasm-examples/whisper/lib-example.html
new file mode 100644
index 00000000..a8c49785
--- /dev/null
+++ b/candle-wasm-examples/whisper/lib-example.html
@@ -0,0 +1,313 @@
+<html>
+ <head>
+ <meta content="text/html;charset=utf-8" http-equiv="Content-Type" />
+ <title>Candle Whisper Rust/WASM</title>
+ </head>
+ <body></body>
+</html>
+
+<!doctype html>
+<html>
+ <head>
+ <meta charset="UTF-8" />
+ <meta name="viewport" content="width=device-width, initial-scale=1.0" />
+ <style>
+ @import url("https://fonts.googleapis.com/css2?family=Source+Code+Pro:wght@200;300;400&family=Source+Sans+3:wght@100;200;300;400;500;600;700;800;900&display=swap");
+ html,
+ body {
+ font-family: "Source Sans 3", sans-serif;
+ }
+ </style>
+ <script src="https://cdn.tailwindcss.com"></script>
+ <script type="module">
+ // base url for audio examples
+ const AUDIO_BASE_URL =
+ "https://huggingface.co/datasets/Narsil/candle-examples/resolve/main/";
+
+ // models base url
+ const MODELS = {
+ tiny_en: {
+ base_url:
+ "https://huggingface.co/openai/whisper-tiny.en/resolve/refs%2Fpr%2F17/",
+ },
+ };
+ const whisperWorker = new Worker("./whisperWorker.js", {
+ type: "module",
+ });
+
+ async function classifyAudio(
+ weightsURL, // URL to the weights file
+ modelID, // model ID
+ tokenizerURL, // URL to the tokenizer file
+ mel_filtersURL, // URL to the mel filters file
+ audioURL, // URL to the audio file
+ updateStatus // function to update the status
+ ) {
+ return new Promise((resolve, reject) => {
+ whisperWorker.postMessage({
+ weightsURL,
+ modelID,
+ tokenizerURL,
+ mel_filtersURL,
+ audioURL,
+ });
+ whisperWorker.addEventListener("message", (event) => {
+ console.log(event.data);
+ if ("status" in event.data) {
+ updateStatus(event.data);
+ }
+ if ("error" in event.data) {
+ reject(new Error(event.data.error));
+ }
+ if (event.data.status === "complete") {
+ resolve(event.data);
+ }
+ });
+ });
+ }
+
+ // keep track of the audio URL
+ let audioURL = null;
+ function setAudio(src) {
+ const audio = document.querySelector("#audio");
+ audio.src = src;
+ audio.controls = true;
+ audio.hidden = false;
+ document.querySelector("#detect").disabled = false;
+ audioURL = src;
+ }
+ // add event listener to audio buttons
+ document.querySelectorAll("#audios-select > button").forEach((target) => {
+ target.addEventListener("click", (e) => {
+ const value = target.dataset.value;
+ const href = AUDIO_BASE_URL + value;
+ setAudio(href);
+ });
+ });
+ //add event listener to file input
+ document.querySelector("#file-upload").addEventListener("change", (e) => {
+ const target = e.target;
+ if (target.files.length > 0) {
+ const href = URL.createObjectURL(target.files[0]);
+ setAudio(href);
+ }
+ });
+ // add event listener to drop-area
+ const dropArea = document.querySelector("#drop-area");
+ dropArea.addEventListener("dragenter", (e) => {
+ e.preventDefault();
+ dropArea.classList.add("border-blue-700");
+ });
+ dropArea.addEventListener("dragleave", (e) => {
+ e.preventDefault();
+ dropArea.classList.remove("border-blue-700");
+ });
+ dropArea.addEventListener("dragover", (e) => {
+ e.preventDefault();
+ dropArea.classList.add("border-blue-700");
+ });
+ dropArea.addEventListener("drop", (e) => {
+ e.preventDefault();
+ dropArea.classList.remove("border-blue-700");
+ const url = e.dataTransfer.getData("text/uri-list");
+ const files = e.dataTransfer.files;
+ if (files.length > 0) {
+ const href = URL.createObjectURL(files[0]);
+ setAudio(href);
+ } else if (url) {
+ setAudio(url);
+ }
+ });
+
+ // add event listener to detect button
+ document.querySelector("#detect").addEventListener("click", async () => {
+ if (audioURL === null) {
+ return;
+ }
+ const modelID = document.querySelector("#model").value;
+ const modelURL = MODELS[modelID].base_url + "model.safetensors";
+ const tokenizerURL = MODELS[modelID].base_url + "tokenizer.json";
+
+ classifyAudio(
+ modelURL,
+ modelID,
+ tokenizerURL,
+ "mel_filters.safetensors",
+ audioURL,
+ updateStatus
+ )
+ .then((result) => {
+ console.log("RESULT", result);
+ const { output } = result;
+ const text = output.map((segment) => segment.dr.text).join(" ");
+ console.log(text);
+ document.getElementById("output").textContent = text;
+ })
+ .catch((error) => {
+ console.error(error);
+ });
+ });
+
+ function updateStatus(data) {
+ const { status, message } = data;
+ const button = document.querySelector("#detect");
+ if (status === "decoding" || status === "loading") {
+ button.disabled = true;
+ button.textContent = message;
+ } else if (status === "complete") {
+ button.disabled = false;
+ button.textContent = "Transcribe Audio";
+ }
+ }
+ </script>
+ </head>
+ <body class="container max-w-4xl mx-auto p-4">
+ <main class="grid grid-cols-1 gap-8 relative">
+ <span class="absolute text-5xl -ml-[1em]"> 🕯️ </span>
+ <div>
+ <h1 class="text-5xl font-bold">Candle Whisper</h1>
+ <h2 class="text-2xl font-bold">Rust/WASM Demo</h2>
+ <p class="max-w-lg">
+ Transcribe audio in the browser using rust/wasm with an audio file.
+ This demo uses the
+ <a
+ href="https://huggingface.co/openai/"
+ target="_blank"
+ class="underline hover:text-blue-500 hover:no-underline"
+ >
+ OpenAI Whisper models
+ </a>
+ and WASM runtime built with
+ <a
+ href="https://github.com/huggingface/candle/"
+ target="_blank"
+ class="underline hover:text-blue-500 hover:no-underline"
+ >Candle
+ </a>
+ </p>
+ </div>
+
+ <div>
+ <label for="model" class="font-medium">Models Options: </label>
+ <select
+ id="model"
+ class="border-2 border-gray-500 rounded-md font-light"
+ >
+ <option value="tiny_en" selected>tiny.en (151 MB)</option>
+ </select>
+ </div>
+ <!-- drag and drop area -->
+ <div class="relative">
+ <div
+ id="drop-area"
+ class="flex flex-col items-center justify-center border-2 border-gray-300 border-dashed rounded-xl relative h-48 w-full overflow-hidden"
+ >
+ <div
+ class="flex flex-col items-center justify-center space-y-1 text-center"
+ >
+ <svg
+ width="25"
+ height="25"
+ viewBox="0 0 25 25"
+ fill="none"
+ xmlns="http://www.w3.org/2000/svg"
+ >
+ <path
+ d="M3.5 24.3a3 3 0 0 1-1.9-.8c-.5-.5-.8-1.2-.8-1.9V2.9c0-.7.3-1.3.8-1.9.6-.5 1.2-.7 2-.7h18.6c.7 0 1.3.2 1.9.7.5.6.7 1.2.7 2v18.6c0 .7-.2 1.4-.7 1.9a3 3 0 0 1-2 .8H3.6Zm0-2.7h18.7V2.9H3.5v18.7Zm2.7-2.7h13.3c.3 0 .5 0 .6-.3v-.7l-3.7-5a.6.6 0 0 0-.6-.2c-.2 0-.4 0-.5.3l-3.5 4.6-2.4-3.3a.6.6 0 0 0-.6-.3c-.2 0-.4.1-.5.3l-2.7 3.6c-.1.2-.2.4 0 .7.1.2.3.3.6.3Z"
+ fill="#000"
+ />
+ </svg>
+ <div class="flex text-sm text-gray-600">
+ <label
+ for="file-upload"
+ class="relative cursor-pointer bg-white rounded-md font-medium text-blue-950 hover:text-blue-700"
+ >
+ <span>Drag and drop your audio here</span>
+ <span class="block text-xs">or</span>
+ <span class="block text-xs">Click to upload</span>
+ </label>
+ </div>
+ <input
+ id="file-upload"
+ name="file-upload"
+ type="file"
+ accept="audio/*"
+ class="sr-only"
+ />
+ </div>
+ <audio
+ id="audio"
+ hidden
+ controls
+ class="w-full p-2 select-none"
+ ></audio>
+ </div>
+ </div>
+ <div>
+ <div class="flex flex-wrap gap-3 items-center" id="audios-select">
+ <h3 class="font-medium">Examples:</h3>
+ <button
+ data-value="samples_jfk.wav"
+ class="text-gray-500 border border-gray-500 rounded-md p-2 underline hover:no-underline"
+ >
+ <span>jfk.wav</span>
+ <span class="text-xs block"> (352 kB)</span>
+ </button>
+ <button
+ data-value="samples_a13.wav"
+ class="text-gray-500 border border-gray-500 rounded-md p-2 underline hover:no-underline"
+ >
+ <span>a13.wav</span>
+ <span class="text-xs block"> (960 kB)</span>
+ </button>
+ <button
+ data-value="samples_mm0.wav"
+ class="text-gray-500 border border-gray-500 rounded-md p-2 underline hover:no-underline"
+ >
+ <span>mm0.wav</span>
+ <span class="text-xs block new"> (957 kB)</span>
+ </button>
+ <button
+ data-value="samples_gb0.wav"
+ class="text-gray-500 border border-gray-500 rounded-md p-2 underline hover:no-underline"
+ >
+ <span>gb0.wav </span>
+ <span class="text-xs block">(4.08 MB)</span>
+ </button>
+ <button
+ data-value="samples_gb1.wav"
+ class="text-gray-500 border border-gray-500 rounded-md p-2 underline hover:no-underline"
+ >
+ <span>gb1.wav </span>
+ <span class="text-xs block">(6.36 MB)</span>
+ </button>
+ <button
+ data-value="samples_hp0.wav"
+ class="text-gray-500 border border-gray-500 rounded-md p-2 underline hover:no-underline"
+ >
+ <span>hp0.wav </span>
+ <span class="text-xs block">(8.75 MB)</span>
+ </button>
+ </div>
+ </div>
+
+ <div>
+ <button
+ id="detect"
+ disabled
+ class="bg-orange-900 hover:bg-orange-800 text-white font-normal py-2 px-4 rounded disabled:opacity-75 disabled:cursor-not-allowed"
+ >
+ Transcribe Audio
+ </button>
+ </div>
+ <div>
+ <h3 class="font-medium">Transcription:</h3>
+
+ <div
+ id="output"
+ class="min-h-[100px] bg-slate-500 text-white p-4 rounded-md"
+ ></div>
+ </div>
+ </main>
+ </body>
+</html>
diff --git a/candle-wasm-examples/whisper/src/bin/m.rs b/candle-wasm-examples/whisper/src/bin/m.rs
new file mode 100644
index 00000000..88b25267
--- /dev/null
+++ b/candle-wasm-examples/whisper/src/bin/m.rs
@@ -0,0 +1,41 @@
+use candle_wasm_example_whisper::worker::{Decoder as D, ModelData};
+use wasm_bindgen::prelude::*;
+
+#[wasm_bindgen]
+pub struct Decoder {
+ decoder: D,
+}
+
+#[wasm_bindgen]
+impl Decoder {
+ #[wasm_bindgen(constructor)]
+ pub fn new(
+ weights: Vec<u8>,
+ tokenizer: Vec<u8>,
+ mel_filters: Vec<u8>,
+ ) -> Result<Decoder, JsError> {
+ let decoder = D::load(ModelData {
+ tokenizer,
+ mel_filters,
+ weights,
+ });
+
+ match decoder {
+ Ok(decoder) => Ok(Self { decoder }),
+ Err(e) => Err(JsError::new(&e.to_string())),
+ }
+ }
+
+ #[wasm_bindgen]
+ pub fn decode(&self, wav_input: Vec<u8>) -> Result<String, JsError> {
+ let segments = self
+ .decoder
+ .convert_and_run(&wav_input)
+ .map_err(|e| JsError::new(&e.to_string()))?;
+
+ let json = serde_json::to_string(&segments)?;
+ Ok(json)
+ }
+}
+
+fn main() {}
diff --git a/candle-wasm-examples/whisper/src/lib.rs b/candle-wasm-examples/whisper/src/lib.rs
index d738ca6a..141714f5 100644
--- a/candle-wasm-examples/whisper/src/lib.rs
+++ b/candle-wasm-examples/whisper/src/lib.rs
@@ -24,6 +24,6 @@ impl Drop for Timer {
mod app;
mod audio;
mod model;
-mod worker;
+pub mod worker;
pub use app::App;
pub use worker::Worker;
diff --git a/candle-wasm-examples/whisper/src/worker.rs b/candle-wasm-examples/whisper/src/worker.rs
index bbcae36c..49b2cd09 100644
--- a/candle-wasm-examples/whisper/src/worker.rs
+++ b/candle-wasm-examples/whisper/src/worker.rs
@@ -222,7 +222,7 @@ impl Decoder {
Ok(segments)
}
- fn load(md: ModelData) -> anyhow::Result<Self> {
+ pub fn load(md: ModelData) -> anyhow::Result<Self> {
let device = Device::Cpu;
let tokenizer = Tokenizer::from_bytes(&md.tokenizer).map_err(anyhow::Error::msg)?;
@@ -239,7 +239,7 @@ impl Decoder {
Ok(decoder)
}
- fn convert_and_run(&self, wav_input: &[u8]) -> anyhow::Result<Vec<Segment>> {
+ pub fn convert_and_run(&self, wav_input: &[u8]) -> anyhow::Result<Vec<Segment>> {
let device = Device::Cpu;
let mut wav_input = std::io::Cursor::new(wav_input);
let (header, data) = wav::read(&mut wav_input)?;
diff --git a/candle-wasm-examples/whisper/whisperWorker.js b/candle-wasm-examples/whisper/whisperWorker.js
new file mode 100644
index 00000000..2598adde
--- /dev/null
+++ b/candle-wasm-examples/whisper/whisperWorker.js
@@ -0,0 +1,72 @@
+// Load the Candle Whisper decoder WASM module.
+import init, { Decoder } from "./build/m.js";
+
+async function fetchArrayBuffer(url) {
+ const res = await fetch(url, {
+ cache: "force-cache",
+ headers: {
+ "Cache-Control": "public, max-age=31536000",
+ },
+ });
+ const data = await res.arrayBuffer();
+ return new Uint8Array(data);
+}
+
+class Whisper {
+ static instance = {};
+ // Retrieve the Whisper model. When called for the first time,
+ // this will load the model and save it for future use.
+ static async getInstance(weightsURL, modelID, tokenizerURL, mel_filtersURL) {
+ // load individual modelID only once
+ if (!this.instance[modelID]) {
+ await init();
+
+ self.postMessage({ status: "loading", message: "Loading Model" });
+ const [weightsArrayU8, tokenizerArrayU8, mel_filtersArrayU8] =
+ await Promise.all([
+ fetchArrayBuffer(weightsURL),
+ fetchArrayBuffer(tokenizerURL),
+ fetchArrayBuffer(mel_filtersURL),
+ ]);
+
+ this.instance[modelID] = new Decoder(
+ weightsArrayU8,
+ tokenizerArrayU8,
+ mel_filtersArrayU8
+ );
+ } else {
+ self.postMessage({ status: "loading", message: "Model Already Loaded" });
+ }
+ return this.instance[modelID];
+ }
+}
+
+self.addEventListener("message", async (event) => {
+ const { weightsURL, modelID, tokenizerURL, mel_filtersURL, audioURL } =
+ event.data;
+ try {
+ self.postMessage({ status: "decoding", message: "Starting Decoder" });
+
+ const decoder = await Whisper.getInstance(
+ weightsURL,
+ modelID,
+ tokenizerURL,
+ mel_filtersURL
+ );
+
+ self.postMessage({ status: "decoding", message: "Loading Audio" });
+ const audioArrayU8 = await fetchArrayBuffer(audioURL);
+
+ self.postMessage({ status: "decoding", message: "Running Decoder..." });
+ const segments = decoder.decode(audioArrayU8);
+
+ // Send the segments back to the main thread as JSON
+ self.postMessage({
+ status: "complete",
+ message: "complete",
+ output: JSON.parse(segments),
+ });
+ } catch (e) {
+ self.postMessage({ error: e });
+ }
+});