author     Radamés Ajna <radamajna@gmail.com>    2023-09-22 07:31:10 -0700
committer  GitHub <noreply@github.com>           2023-09-22 15:31:10 +0100
commit     19e52e5007e10816eefb2e1a1968be760c5d11a4 (patch)
tree       f0d9cad35d261c3a28d2c4fa5a8ff1af84a4631f
parent     8601537e31af610c0bbd32ee8c8ee17ed802427c (diff)
download   candle-19e52e5007e10816eefb2e1a1968be760c5d11a4.tar.gz
           candle-19e52e5007e10816eefb2e1a1968be760c5d11a4.tar.bz2
           candle-19e52e5007e10816eefb2e1a1968be760c5d11a4.zip
T5 Wasm (#918)

* init t5 wasm model
* split workers for each model
* clean up
* add some ui
* readme
* index
* typo
* remove cache param, clear_kv_cache
* add max_length as param
* add model tasks option to ui
* add method to load quantized gguf from buffer
* Add quantized wasm module
* add quantized models to UI, dynamic import wasms
* link to quantized
* fix copy
* fix ModelEncoder
* fix README.md
-rw-r--r--  Cargo.toml                                                 |   1
-rw-r--r--  candle-transformers/src/models/quantized_t5.rs             |  15
-rw-r--r--  candle-wasm-examples/t5/Cargo.toml                         |  33
-rw-r--r--  candle-wasm-examples/t5/README.md                          |  32
-rw-r--r--  candle-wasm-examples/t5/T5ModelConditionalGeneration.js    |  93
-rw-r--r--  candle-wasm-examples/t5/T5ModelEncoderWorker.js            |  83
-rw-r--r--  candle-wasm-examples/t5/build-lib.sh                       |   3
-rw-r--r--  candle-wasm-examples/t5/index.html                         | 276
-rw-r--r--  candle-wasm-examples/t5/src/bin/m-quantized.rs             | 205
-rw-r--r--  candle-wasm-examples/t5/src/bin/m.rs                       | 206
-rw-r--r--  candle-wasm-examples/t5/src/lib.rs                         |  16
-rw-r--r--  candle-wasm-examples/t5/utils.js                           | 168
12 files changed, 1131 insertions(+), 0 deletions(-)
diff --git a/Cargo.toml b/Cargo.toml
index 6cbbf00f..5ae64523 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -12,6 +12,7 @@ members = [
"candle-wasm-examples/whisper",
"candle-wasm-examples/yolo",
"candle-wasm-examples/bert",
+ "candle-wasm-examples/t5",
]
exclude = ["candle-flash-attn", "candle-kernels"]
resolver = "2"
diff --git a/candle-transformers/src/models/quantized_t5.rs b/candle-transformers/src/models/quantized_t5.rs
index a10c3b80..a86dfcb3 100644
--- a/candle-transformers/src/models/quantized_t5.rs
+++ b/candle-transformers/src/models/quantized_t5.rs
@@ -30,6 +30,21 @@ impl VarBuilder {
        })
    }
+    pub fn from_gguf_buffer(buffer: &[u8]) -> Result<Self> {
+        let mut cursor = std::io::Cursor::new(buffer);
+        let content = candle::quantized::gguf_file::Content::read(&mut cursor)?;
+        let mut data = std::collections::HashMap::new();
+        for tensor_name in content.tensor_infos.keys() {
+            let tensor = content.tensor(&mut cursor, tensor_name)?;
+            data.insert(tensor_name.to_string(), Arc::new(tensor));
+        }
+        Ok(Self {
+            data: Arc::new(data),
+            path: Vec::new(),
+            device: Device::Cpu,
+        })
+    }
+
    fn pp<S: ToString>(&self, s: S) -> Self {
        let mut path = self.path.clone();
        path.push(s.to_string());
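In practice, this buffer-based loader is what lets the browser side hand raw GGUF bytes to the wasm bindings. A minimal JS sketch of that path, assuming the wasm package built from `m-quantized.rs` below (the URLs are placeholders):

```js
// Sketch: fetch raw GGUF/tokenizer/config bytes and hand them to the wasm
// constructor, which forwards the weights to VarBuilder::from_gguf_buffer.
import init, { ModelConditionalGeneration } from "./build/m-quantized.js";

async function loadQuantizedT5(weightsURL, tokenizerURL, configURL) {
  await init(); // initialize the wasm module before any constructor call
  const [weights, tokenizer, config] = await Promise.all(
    [weightsURL, tokenizerURL, configURL].map(async (url) =>
      new Uint8Array(await (await fetch(url)).arrayBuffer())
    )
  );
  return new ModelConditionalGeneration(weights, tokenizer, config);
}
```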
diff --git a/candle-wasm-examples/t5/Cargo.toml b/candle-wasm-examples/t5/Cargo.toml
new file mode 100644
index 00000000..b011d2e5
--- /dev/null
+++ b/candle-wasm-examples/t5/Cargo.toml
@@ -0,0 +1,33 @@
+[package]
+name = "candle-wasm-example-t5"
+version.workspace = true
+edition.workspace = true
+description.workspace = true
+repository.workspace = true
+keywords.workspace = true
+categories.workspace = true
+license.workspace = true
+
+[dependencies]
+candle = { path = "../../candle-core", version = "0.2.2", package = "candle-core" }
+candle-nn = { path = "../../candle-nn", version = "0.2.2" }
+candle-transformers = { path = "../../candle-transformers", version = "0.2.2" }
+num-traits = { workspace = true }
+tokenizers = { workspace = true, features = ["unstable_wasm"] }
+
+# App crates.
+anyhow = { workspace = true }
+byteorder = { workspace = true }
+log = { workspace = true }
+rand = { workspace = true }
+serde = { workspace = true }
+serde_json = { workspace = true }
+safetensors = { workspace = true }
+
+# Wasm specific crates.
+console_error_panic_hook = "0.1.7"
+getrandom = { version = "0.2", features = ["js"] }
+gloo = "0.8"
+js-sys = "0.3.64"
+wasm-bindgen = "0.2.87"
+serde-wasm-bindgen = "0.6.0"
diff --git a/candle-wasm-examples/t5/README.md b/candle-wasm-examples/t5/README.md
new file mode 100644
index 00000000..9a9f5bce
--- /dev/null
+++ b/candle-wasm-examples/t5/README.md
@@ -0,0 +1,32 @@
+## Running T5 with Candle and WASM
+
+Here, we provide two examples of how to run T5 using a Candle-compiled WASM binary and runtime.
+
+### Vanilla JS and WebWorkers
+
+To build and test the UI made in Vanilla JS and WebWorkers, first we need to build the WASM library:
+
+```bash
+sh build-lib.sh
+```
+
+This will bundle the library under `./build`, and we can then import it inside our WebWorker like a normal JS module:
+
+```js
+import init, { ModelConditionalGeneration, ModelEncoder } from "./build/m.js";
+```
+
+For the quantized version, we need to import the quantized module:
+
+```js
+import init, { ModelConditionalGeneration, ModelEncoder } from "./build/m-quantized.js";
+```
+
+The full example can be found under `./index.html`. All of the needed assets are fetched from the web, so there is no need to download anything.
+Finally, you can preview the example by running a local HTTP server. For example:
+
+```bash
+python -m http.server
+```
+
+Then open `http://localhost:8000/index.html` in your browser.
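For orientation, here is a stripped-down sketch of the worker pattern used in `T5ModelConditionalGeneration.js` below; caching, quantized/non-quantized model selection, and error handling are omitted, and the sampling values are just the defaults the worker assumes:

```js
// minimal-worker.js — a sketch of the WebWorker pattern used in this example.
import init, { ModelConditionalGeneration } from "./build/m.js";

let model;
self.addEventListener("message", async ({ data }) => {
  if (!model) {
    await init(); // the wasm module must be initialized once per worker
    const bytes = async (url) =>
      new Uint8Array(await (await fetch(url)).arrayBuffer());
    model = new ModelConditionalGeneration(
      await bytes(data.weightsURL),
      await bytes(data.tokenizerURL),
      await bytes(data.configURL)
    );
  }
  const output = model.decode({
    prompt: data.prompt,
    temperature: 0.0,
    seed: 299792458,
    top_p: 1.0,
    repeat_penalty: 1.1,
    repeat_last_n: 64,
  });
  self.postMessage({ status: "complete", output });
});
```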
diff --git a/candle-wasm-examples/t5/T5ModelConditionalGeneration.js b/candle-wasm-examples/t5/T5ModelConditionalGeneration.js
new file mode 100644
index 00000000..5f94c19a
--- /dev/null
+++ b/candle-wasm-examples/t5/T5ModelConditionalGeneration.js
@@ -0,0 +1,93 @@
+// Load the Candle T5 wasm module; the concrete build is imported dynamically below.
+let init, ModelConditionalGeneration;
+
+async function fetchArrayBuffer(url) {
+ const cacheName = "t5-candle-cache";
+ const cache = await caches.open(cacheName);
+ const cachedResponse = await cache.match(url);
+ if (cachedResponse) {
+ const data = await cachedResponse.arrayBuffer();
+ return new Uint8Array(data);
+ }
+ const res = await fetch(url, { cache: "force-cache" });
+ cache.put(url, res.clone());
+ return new Uint8Array(await res.arrayBuffer());
+}
+class ConditionalGeneration {
+ static instance = {};
+
+ static async getInstance(weightsURL, tokenizerURL, configURL, modelID) {
+ if (modelID.includes("quantized")) {
+ ({ default: init, ModelConditionalGeneration } = await import(
+ "./build/m-quantized.js"
+ ));
+ } else {
+ ({ default: init, ModelConditionalGeneration } = await import(
+ "./build/m.js"
+ ));
+ }
+ if (!this.instance[modelID]) {
+ await init();
+
+ self.postMessage({ status: "loading", message: "Loading Model" });
+ const [weightsArrayU8, tokenizerArrayU8, configArrayU8] =
+ await Promise.all([
+ fetchArrayBuffer(weightsURL),
+ fetchArrayBuffer(tokenizerURL),
+ fetchArrayBuffer(configURL),
+ ]);
+
+ this.instance[modelID] = new ModelConditionalGeneration(
+ weightsArrayU8,
+ tokenizerArrayU8,
+ configArrayU8
+ );
+ } else {
+ self.postMessage({ status: "ready", message: "Model Already Loaded" });
+ }
+ return this.instance[modelID];
+ }
+}
+
+self.addEventListener("message", async (event) => {
+ const { weightsURL, tokenizerURL, configURL, modelID, prompt, params } =
+ event.data;
+  let {
+    temperature = 0.0,
+    seed = 299792458,
+    repeat_penalty = 1.1,
+    repeat_last_n = 64,
+    top_p = 1,
+    max_length,
+  } = { ...params };
+ try {
+ self.postMessage({
+ status: "ready",
+ message: "Starting T5 Conditional Generation",
+ });
+ const model = await ConditionalGeneration.getInstance(
+ weightsURL,
+ tokenizerURL,
+ configURL,
+ modelID
+ );
+ self.postMessage({
+ status: "decoding",
+ message: "Decoding Prompt",
+ });
+    const output = model.decode({
+      prompt,
+      temperature,
+      seed,
+      top_p,
+      repeat_penalty,
+      repeat_last_n,
+      max_length,
+    });
+ self.postMessage({
+ status: "complete",
+ message: "complete",
+ output: output,
+ });
+ } catch (e) {
+ self.postMessage({ error: e });
+ }
+});
diff --git a/candle-wasm-examples/t5/T5ModelEncoderWorker.js b/candle-wasm-examples/t5/T5ModelEncoderWorker.js
new file mode 100644
index 00000000..a83b0ee0
--- /dev/null
+++ b/candle-wasm-examples/t5/T5ModelEncoderWorker.js
@@ -0,0 +1,83 @@
+// Load the Candle T5 encoder wasm module; the concrete build is imported dynamically below.
+let init, ModelEncoder;
+
+async function fetchArrayBuffer(url) {
+ const cacheName = "t5-candle-cache";
+ const cache = await caches.open(cacheName);
+ const cachedResponse = await cache.match(url);
+ if (cachedResponse) {
+ const data = await cachedResponse.arrayBuffer();
+ return new Uint8Array(data);
+ }
+ const res = await fetch(url, { cache: "force-cache" });
+ cache.put(url, res.clone());
+ return new Uint8Array(await res.arrayBuffer());
+}
+class Encoder {
+ static instance = {};
+
+ static async getInstance(weightsURL, tokenizerURL, configURL, modelID) {
+ if (modelID.includes("quantized")) {
+ ({ default: init, ModelEncoder } = await import(
+ "./build/m-quantized.js"
+ ));
+ } else {
+ ({ default: init, ModelEncoder } = await import("./build/m.js"));
+ }
+ if (!this.instance[modelID]) {
+ await init();
+
+ self.postMessage({ status: "loading", message: "Loading Model" });
+ const [weightsArrayU8, tokenizerArrayU8, configArrayU8] =
+ await Promise.all([
+ fetchArrayBuffer(weightsURL),
+ fetchArrayBuffer(tokenizerURL),
+ fetchArrayBuffer(configURL),
+ ]);
+
+ this.instance[modelID] = new ModelEncoder(
+ weightsArrayU8,
+ tokenizerArrayU8,
+ configArrayU8
+ );
+ } else {
+ self.postMessage({ status: "ready", message: "Model Already Loaded" });
+ }
+ return this.instance[modelID];
+ }
+}
+
+self.addEventListener("message", async (event) => {
+ const {
+ weightsURL,
+ tokenizerURL,
+ configURL,
+ modelID,
+ sentences,
+ normalize_embeddings,
+ } = event.data;
+ try {
+ self.postMessage({ status: "ready", message: "Starting T5 Encoder" });
+ const model = await Encoder.getInstance(
+ weightsURL,
+ tokenizerURL,
+ configURL,
+ modelID
+ );
+ self.postMessage({
+ status: "encoding",
+ message: "Encoding Sentences",
+ });
+    const output = model.decode({
+      sentences,
+      normalize_embeddings: normalize_embeddings ?? true,
+    });
+ self.postMessage({
+ status: "complete",
+ message: "complete",
+ output: output,
+ });
+ } catch (e) {
+ self.postMessage({ error: e });
+ }
+});
diff --git a/candle-wasm-examples/t5/build-lib.sh b/candle-wasm-examples/t5/build-lib.sh
new file mode 100644
index 00000000..a311f69c
--- /dev/null
+++ b/candle-wasm-examples/t5/build-lib.sh
@@ -0,0 +1,3 @@
+cargo build --target wasm32-unknown-unknown --release
+wasm-bindgen ../../target/wasm32-unknown-unknown/release/m.wasm --out-dir build --target web
+wasm-bindgen ../../target/wasm32-unknown-unknown/release/m-quantized.wasm --out-dir build --target web
diff --git a/candle-wasm-examples/t5/index.html b/candle-wasm-examples/t5/index.html
new file mode 100644
index 00000000..82b4e696
--- /dev/null
+++ b/candle-wasm-examples/t5/index.html
@@ -0,0 +1,276 @@
+<!DOCTYPE html>
+<html>
+  <head>
+    <meta charset="UTF-8" />
+    <title>Candle T5</title>
+    <meta name="viewport" content="width=device-width, initial-scale=1.0" />
+ <style>
+ @import url("https://fonts.googleapis.com/css2?family=Source+Code+Pro:wght@200;300;400&family=Source+Sans+3:wght@100;200;300;400;500;600;700;800;900&display=swap");
+
+ html,
+ body {
+ font-family: "Source Sans 3", sans-serif;
+ }
+ </style>
+ <style type="text/tailwindcss">
+ .link {
+ @apply underline hover:text-blue-500 hover:no-underline;
+ }
+ </style>
+ <script src="https://cdn.tailwindcss.com"></script>
+ <script type="module">
+ import {
+ getModelInfo,
+ MODELS,
+ extractEmbeddings,
+ generateText,
+ } from "./utils.js";
+
+ const t5ModelEncoderWorker = new Worker("./T5ModelEncoderWorker.js", {
+ type: "module",
+ });
+ const t5ModelConditionalGeneration = new Worker(
+ "./T5ModelConditionalGeneration.js",
+ { type: "module" }
+ );
+
+ const formEl = document.querySelector("#form");
+ const modelEl = document.querySelector("#model");
+ const promptEl = document.querySelector("#prompt");
+ const temperatureEl = document.querySelector("#temperature");
+      const toppEl = document.querySelector("#top-p");
+ const repeatPenaltyEl = document.querySelector("#repeat_penalty");
+ const seedEl = document.querySelector("#seed");
+ const outputEl = document.querySelector("#output-generation");
+ const tasksEl = document.querySelector("#tasks");
+ let selectedTaskID = "";
+
+ document.addEventListener("DOMContentLoaded", () => {
+ for (const [id, model] of Object.entries(MODELS)) {
+ const option = document.createElement("option");
+ option.value = id;
+ option.innerText = `${id} (${model.size})`;
+ modelEl.appendChild(option);
+ }
+ populateTasks(modelEl.value);
+ modelEl.addEventListener("change", (e) => {
+ populateTasks(e.target.value);
+ });
+ tasksEl.addEventListener("change", (e) => {
+ const task = e.target.value;
+ const modelID = modelEl.value;
+ promptEl.value = MODELS[modelID].tasks[task].prefix;
+ selectedTaskID = task;
+ });
+ });
+ function populateTasks(modelID) {
+ const tasks = MODELS[modelID].tasks;
+ tasksEl.innerHTML = "";
+ for (const [task, params] of Object.entries(tasks)) {
+ const div = document.createElement("div");
+ div.innerHTML = `
+ <input
+ type="radio"
+ name="task"
+ id="${task}"
+ class="font-light cursor-pointer"
+ value="${task}" />
+ <label for="${task}" class="cursor-pointer">
+ ${params.prefix}
+ </label>
+ `;
+ tasksEl.appendChild(div);
+ }
+ selectedTaskID = Object.keys(tasks)[0];
+ tasksEl.querySelector(`#${selectedTaskID}`).checked = true;
+ }
+      formEl.addEventListener("submit", (e) => {
+ e.preventDefault();
+
+ const promptText = promptEl.value;
+ const modelID = modelEl.value;
+ const { modelURL, configURL, tokenizerURL, maxLength } = getModelInfo(
+ modelID,
+ selectedTaskID
+ );
+ const params = {
+ temperature: Number(temperatureEl.value),
+          top_p: Number(toppEl.value),
+          repeat_penalty: Number(repeatPenaltyEl.value),
+ seed: BigInt(seedEl.value),
+ max_length: maxLength,
+ };
+ generateText(
+ t5ModelConditionalGeneration,
+ modelURL,
+ tokenizerURL,
+ configURL,
+ modelID,
+ promptText,
+ params,
+ (status) => {
+ if (status.status === "loading") {
+ outputEl.innerText = "Loading model...";
+ }
+ if (status.status === "decoding") {
+ outputEl.innerText = "Generating...";
+ }
+ }
+ ).then(({ output }) => {
+ outputEl.innerText = output.generation;
+ });
+ });
+ </script>
+ </head>
+
+ <body class="container max-w-4xl mx-auto p-4">
+ <main class="grid grid-cols-1 gap-8 relative">
+ <span class="absolute text-5xl -ml-[1em]"> 🕯️ </span>
+ <div>
+ <h1 class="text-5xl font-bold">Candle T5 Transformer</h1>
+ <h2 class="text-2xl font-bold">Rust/WASM Demo</h2>
+ <p class="max-w-lg">
+            This demo showcases Text-To-Text Transfer Transformer (<a
+ href="https://blog.research.google/2020/02/exploring-transfer-learning-with-t5.html"
+ target="_blank"
+ class="link"
+ >T5</a
+            >) models right in your browser, thanks to the
+ <a
+ href="https://github.com/huggingface/candle/"
+ target="_blank"
+ class="link">
+ Candle
+ </a>
+            ML framework and Rust/WASM. You can choose from a range of available
+ models, including
+ <a
+ href="https://huggingface.co/t5-small"
+ target="_blank"
+ class="link">
+ t5-small</a
+ >,
+ <a href="https://huggingface.co/t5-base" target="_blank" class="link"
+ >t5-base</a
+ >,
+ <a
+ href="https://huggingface.co/google/flan-t5-small"
+ target="_blank"
+ class="link"
+ >flan-t5-small</a
+ >
+            and several quantized
+            <a
+              href="https://huggingface.co/lmz/candle-quantized-t5/tree/main"
+              target="_blank"
+              class="link">
+              t5 GGUF models</a
+            >.
+ </p>
+ </div>
+
+ <div>
+ <label for="model" class="font-medium">Models Options: </label>
+ <select
+ id="model"
+ class="border-2 border-gray-500 rounded-md font-light"></select>
+ </div>
+
+ <div>
+ <h3 class="font-medium">Task Prefix:</h3>
+ <form id="tasks" class="flex flex-col gap-1 my-2"></form>
+ </div>
+ <form
+ id="form"
+ class="flex text-normal px-1 py-1 border border-gray-700 rounded-md items-center">
+ <input type="submit" hidden />
+ <input
+ type="text"
+ id="prompt"
+ class="font-light w-full px-3 py-2 mx-1 resize-none outline-none"
+ placeholder="Add prompt here, e.g. 'translate English to German: Today I'm going to eat Ice Cream'"
+ value="translate English to German: Today I'm going to eat Ice Cream" />
+ <button
+ class="bg-gray-700 hover:bg-gray-800 text-white font-normal py-2 w-16 rounded disabled:bg-gray-300 disabled:cursor-not-allowed">
+ Run
+ </button>
+ </form>
+ <div class="grid grid-cols-3 max-w-md items-center gap-3">
+ <label class="text-sm font-medium" for="temperature">Temperature</label>
+ <input
+ type="range"
+ id="temperature"
+ name="temperature"
+ min="0"
+ max="2"
+ step="0.01"
+ value="0.00"
+ oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)" />
+ <output
+ class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md">
+ 0.00</output
+ >
+ <label class="text-sm font-medium" for="top-p">Top-p</label>
+ <input
+ type="range"
+ id="top-p"
+ name="top-p"
+ min="0"
+ max="1"
+ step="0.01"
+ value="1.00"
+ oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)" />
+ <output
+ class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md">
+ 1.00</output
+ >
+
+ <label class="text-sm font-medium" for="repeat_penalty"
+ >Repeat Penalty</label
+ >
+
+ <input
+ type="range"
+ id="repeat_penalty"
+ name="repeat_penalty"
+ min="-2"
+ max="2"
+ step="0.01"
+ value="1.10"
+ oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)" />
+ <output
+ class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md"
+ >1.10</output
+ >
+ <label class="text-sm font-medium" for="seed">Seed</label>
+ <input
+ type="number"
+ id="seed"
+ name="seed"
+ value="299792458"
+ class="font-light border border-gray-700 text-right rounded-md p-2" />
+ <button
+ id="run"
+          onclick="document.querySelector('#seed').value = BigInt(Math.floor(Math.random() * Number.MAX_SAFE_INTEGER))"
+ class="bg-gray-700 hover:bg-gray-800 text-white font-normal py-1 w-[50px] rounded disabled:bg-gray-300 disabled:cursor-not-allowed text-sm">
+ Rand
+ </button>
+ </div>
+ <div>
+ <h3 class="font-medium">Generation:</h3>
+ <div
+ class="min-h-[250px] bg-slate-100 text-gray-500 p-4 rounded-md flex flex-col gap-2 text-lg">
+ <p id="output-generation" class="grid-rows-2">No output yet</p>
+ </div>
+ </div>
+ </main>
+ </body>
+</html>
diff --git a/candle-wasm-examples/t5/src/bin/m-quantized.rs b/candle-wasm-examples/t5/src/bin/m-quantized.rs
new file mode 100644
index 00000000..2f490b84
--- /dev/null
+++ b/candle-wasm-examples/t5/src/bin/m-quantized.rs
@@ -0,0 +1,205 @@
+use candle::{Device, Tensor};
+use candle_transformers::generation::LogitsProcessor;
+pub use candle_transformers::models::quantized_t5::{
+    Config, T5EncoderModel, T5ForConditionalGeneration, VarBuilder,
+};
+
+use candle_wasm_example_t5::console_log;
+use tokenizers::Tokenizer;
+use wasm_bindgen::prelude::*;
+
+#[wasm_bindgen]
+pub struct ModelEncoder {
+    model: T5EncoderModel,
+    tokenizer: Tokenizer,
+}
+
+#[wasm_bindgen]
+pub struct ModelConditionalGeneration {
+    model: T5ForConditionalGeneration,
+    tokenizer: Tokenizer,
+    config: Config,
+}
+
+#[wasm_bindgen]
+impl ModelConditionalGeneration {
+    #[wasm_bindgen(constructor)]
+    pub fn load(
+        weights: Vec<u8>,
+        tokenizer: Vec<u8>,
+        config: Vec<u8>,
+    ) -> Result<ModelConditionalGeneration, JsError> {
+        console_error_panic_hook::set_once();
+        console_log!("loading model");
+        let vb = VarBuilder::from_gguf_buffer(&weights)?;
+        let mut config: Config = serde_json::from_slice(&config)?;
+        let tokenizer =
+            Tokenizer::from_bytes(&tokenizer).map_err(|m| JsError::new(&m.to_string()))?;
+        let model = T5ForConditionalGeneration::load(vb, &config)?;
+        config.use_cache = false;
+        Ok(Self {
+            model,
+            tokenizer,
+            config,
+        })
+    }
+    pub fn decode(&mut self, input: JsValue) -> Result<JsValue, JsError> {
+        let input: ConditionalGenerationParams =
+            serde_wasm_bindgen::from_value(input).map_err(|m| JsError::new(&m.to_string()))?;
+        let device = &Device::Cpu;
+        self.model.clear_kv_cache();
+        let mut output_token_ids = vec![self.config.pad_token_id as u32];
+        let prompt = input.prompt;
+        let repeat_penalty = input.repeat_penalty;
+        let repeat_last_n = input.repeat_last_n;
+        let seed = input.seed;
+        let max_length = usize::clamp(input.max_length.unwrap_or(512), 0, 512);
+        let temperature = if input.temperature <= 0. {
+            None
+        } else {
+            Some(input.temperature)
+        };
+        let top_p = if input.top_p <= 0. || input.top_p >= 1. {
+            None
+        } else {
+            Some(input.top_p)
+        };
+        let mut logits_processor = LogitsProcessor::new(seed, temperature, top_p);
+        let tokens = self
+            .tokenizer
+            .encode(prompt, true)
+            .map_err(|m| JsError::new(&m.to_string()))?
+            .get_ids()
+            .to_vec();
+
+        let input_token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
+        let encoder_output = self.model.encode(&input_token_ids)?;
+        let mut decoded = String::new();
+        for index in 0.. {
+            if output_token_ids.len() > max_length {
+                break;
+            }
+            let decoder_token_ids = if index == 0 {
+                Tensor::new(output_token_ids.as_slice(), device)?.unsqueeze(0)?
+            } else {
+                let last_token = *output_token_ids.last().unwrap();
+                Tensor::new(&[last_token], device)?.unsqueeze(0)?
+            };
+            let logits = self
+                .model
+                .decode(&decoder_token_ids, &encoder_output)?
+                .squeeze(0)?;
+            let logits = if repeat_penalty == 1. {
+                logits
+            } else {
+                let start_at = output_token_ids.len().saturating_sub(repeat_last_n);
+                candle_transformers::utils::apply_repeat_penalty(
+                    &logits,
+                    repeat_penalty,
+                    &output_token_ids[start_at..],
+                )?
+            };
+
+            let next_token_id = logits_processor.sample(&logits)?;
+            if next_token_id as usize == self.config.eos_token_id {
+                break;
+            }
+            output_token_ids.push(next_token_id);
+            if let Some(text) = self.tokenizer.id_to_token(next_token_id) {
+                let text = text.replace('▁', " ").replace("<0x0A>", "\n");
+                decoded += &text;
+            }
+        }
+        Ok(serde_wasm_bindgen::to_value(
+            &ConditionalGenerationOutput {
+                generation: decoded,
+            },
+        )?)
+    }
+}
+
+#[wasm_bindgen]
+impl ModelEncoder {
+    #[wasm_bindgen(constructor)]
+    pub fn load(
+        weights: Vec<u8>,
+        tokenizer: Vec<u8>,
+        config: Vec<u8>,
+    ) -> Result<ModelEncoder, JsError> {
+        console_error_panic_hook::set_once();
+        console_log!("loading model");
+        let vb = VarBuilder::from_gguf_buffer(&weights)?;
+        let mut config: Config = serde_json::from_slice(&config)?;
+        config.use_cache = false;
+        let tokenizer =
+            Tokenizer::from_bytes(&tokenizer).map_err(|m| JsError::new(&m.to_string()))?;
+        let model = T5EncoderModel::load(vb, &config)?;
+        Ok(Self { model, tokenizer })
+    }
+
+    pub fn decode(&mut self, input: JsValue) -> Result<JsValue, JsError> {
+        let device = &Device::Cpu;
+        let input: DecoderParams =
+            serde_wasm_bindgen::from_value(input).map_err(|m| JsError::new(&m.to_string()))?;
+
+        self.model.clear_kv_cache();
+        let sentences = input.sentences;
+        let normalize_embeddings = input.normalize_embeddings;
+        let n_sentences = sentences.len();
+        let mut all_embeddings = Vec::with_capacity(n_sentences);
+        for sentence in sentences {
+            let tokens = self
+                .tokenizer
+                .encode(sentence, true)
+                .map_err(|m| JsError::new(&m.to_string()))?
+                .get_ids()
+                .to_vec();
+            let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
+            let embeddings = self.model.forward(&token_ids)?;
+            console_log!("generated embeddings {:?}", embeddings.shape());
+            // Average pooling: take the mean embedding over all tokens (including padding).
+            let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?;
+            let embeddings = (embeddings.sum(1)? / (n_tokens as f64))?;
+            let embeddings = if normalize_embeddings {
+                embeddings.broadcast_div(&embeddings.sqr()?.sum_keepdim(1)?.sqrt()?)?
+            } else {
+                embeddings
+            };
+            console_log!("{:?}", embeddings.shape());
+            all_embeddings.push(embeddings.squeeze(0)?.to_vec1::<f32>()?);
+        }
+
+        Ok(serde_wasm_bindgen::to_value(&DecoderOutput {
+            embeddings: all_embeddings,
+        })?)
+    }
+}
+
+#[derive(serde::Serialize, serde::Deserialize)]
+struct ConditionalGenerationOutput {
+    generation: String,
+}
+
+#[derive(serde::Serialize, serde::Deserialize)]
+struct DecoderOutput {
+    embeddings: Vec<Vec<f32>>,
+}
+
+#[derive(serde::Serialize, serde::Deserialize)]
+pub struct DecoderParams {
+    sentences: Vec<String>,
+    normalize_embeddings: bool,
+}
+#[derive(serde::Serialize, serde::Deserialize)]
+pub struct ConditionalGenerationParams {
+    prompt: String,
+    temperature: f64,
+    seed: u64,
+    top_p: f64,
+    repeat_penalty: f32,
+    repeat_last_n: usize,
+    max_length: Option<usize>,
+}
+fn main() {
+    console_error_panic_hook::set_once();
+}
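The JS value that `decode` deserializes via `serde_wasm_bindgen` mirrors the `ConditionalGenerationParams` struct above. A sketch of a call, assuming `model` was constructed as in the worker files:

```js
// Shape mirrors ConditionalGenerationParams; field names must match exactly.
const output = model.decode({
  prompt: "translate English to German: Hello",
  temperature: 0.0, // <= 0 disables temperature sampling
  seed: 299792458, // u64 on the Rust side; a BigInt also works
  top_p: 1.0, // values outside (0, 1) disable nucleus sampling
  repeat_penalty: 1.1,
  repeat_last_n: 64,
  max_length: 300, // optional; clamped to 512 in decode
});
console.log(output.generation);
```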
diff --git a/candle-wasm-examples/t5/src/bin/m.rs b/candle-wasm-examples/t5/src/bin/m.rs
new file mode 100644
index 00000000..c82e00cd
--- /dev/null
+++ b/candle-wasm-examples/t5/src/bin/m.rs
@@ -0,0 +1,206 @@
+use candle::{DType, Device, Tensor};
+use candle_nn::VarBuilder;
+use candle_transformers::generation::LogitsProcessor;
+pub use candle_transformers::models::t5::{Config, T5EncoderModel, T5ForConditionalGeneration};
+use candle_wasm_example_t5::console_log;
+use tokenizers::Tokenizer;
+use wasm_bindgen::prelude::*;
+#[wasm_bindgen]
+pub struct ModelEncoder {
+    model: T5EncoderModel,
+    tokenizer: Tokenizer,
+}
+
+#[wasm_bindgen]
+pub struct ModelConditionalGeneration {
+    model: T5ForConditionalGeneration,
+    tokenizer: Tokenizer,
+    config: Config,
+}
+
+#[wasm_bindgen]
+impl ModelConditionalGeneration {
+    #[wasm_bindgen(constructor)]
+    pub fn load(
+        weights: Vec<u8>,
+        tokenizer: Vec<u8>,
+        config: Vec<u8>,
+    ) -> Result<ModelConditionalGeneration, JsError> {
+        console_error_panic_hook::set_once();
+        console_log!("loading model");
+        let device = &Device::Cpu;
+        let weights = safetensors::tensor::SafeTensors::deserialize(&weights)?;
+        let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, device);
+        let mut config: Config = serde_json::from_slice(&config)?;
+        let tokenizer =
+            Tokenizer::from_bytes(&tokenizer).map_err(|m| JsError::new(&m.to_string()))?;
+        let model = T5ForConditionalGeneration::load(vb, &config)?;
+        config.use_cache = false;
+        Ok(Self {
+            model,
+            tokenizer,
+            config,
+        })
+    }
+    pub fn decode(&mut self, input: JsValue) -> Result<JsValue, JsError> {
+        let input: ConditionalGenerationParams =
+            serde_wasm_bindgen::from_value(input).map_err(|m| JsError::new(&m.to_string()))?;
+        let device = &Device::Cpu;
+        self.model.clear_kv_cache();
+        let mut output_token_ids = vec![self.config.pad_token_id as u32];
+        let prompt = input.prompt;
+        let repeat_penalty = input.repeat_penalty;
+        let repeat_last_n = input.repeat_last_n;
+        let seed = input.seed;
+        let max_length = usize::clamp(input.max_length.unwrap_or(512), 0, 512);
+        let temperature = if input.temperature <= 0. {
+            None
+        } else {
+            Some(input.temperature)
+        };
+        let top_p = if input.top_p <= 0. || input.top_p >= 1. {
+            None
+        } else {
+            Some(input.top_p)
+        };
+        let mut logits_processor = LogitsProcessor::new(seed, temperature, top_p);
+        let tokens = self
+            .tokenizer
+            .encode(prompt, true)
+            .map_err(|m| JsError::new(&m.to_string()))?
+            .get_ids()
+            .to_vec();
+
+        let input_token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
+        let encoder_output = self.model.encode(&input_token_ids)?;
+        let mut decoded = String::new();
+        for index in 0.. {
+            if output_token_ids.len() > max_length {
+                break;
+            }
+            let decoder_token_ids = if index == 0 {
+                Tensor::new(output_token_ids.as_slice(), device)?.unsqueeze(0)?
+            } else {
+                let last_token = *output_token_ids.last().unwrap();
+                Tensor::new(&[last_token], device)?.unsqueeze(0)?
+            };
+            let logits = self
+                .model
+                .decode(&decoder_token_ids, &encoder_output)?
+                .squeeze(0)?;
+            let logits = if repeat_penalty == 1. {
+                logits
+            } else {
+                let start_at = output_token_ids.len().saturating_sub(repeat_last_n);
+                candle_transformers::utils::apply_repeat_penalty(
+                    &logits,
+                    repeat_penalty,
+                    &output_token_ids[start_at..],
+                )?
+            };
+
+            let next_token_id = logits_processor.sample(&logits)?;
+            if next_token_id as usize == self.config.eos_token_id {
+                break;
+            }
+            output_token_ids.push(next_token_id);
+            if let Some(text) = self.tokenizer.id_to_token(next_token_id) {
+                let text = text.replace('▁', " ").replace("<0x0A>", "\n");
+                decoded += &text;
+            }
+        }
+        Ok(serde_wasm_bindgen::to_value(
+            &ConditionalGenerationOutput {
+                generation: decoded,
+            },
+        )?)
+    }
+}
+
+#[wasm_bindgen]
+impl ModelEncoder {
+    #[wasm_bindgen(constructor)]
+    pub fn load(
+        weights: Vec<u8>,
+        tokenizer: Vec<u8>,
+        config: Vec<u8>,
+    ) -> Result<ModelEncoder, JsError> {
+        console_error_panic_hook::set_once();
+        console_log!("loading model");
+        let device = &Device::Cpu;
+        let weights = safetensors::tensor::SafeTensors::deserialize(&weights)?;
+        let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, device);
+        let mut config: Config = serde_json::from_slice(&config)?;
+        config.use_cache = false;
+        let tokenizer =
+            Tokenizer::from_bytes(&tokenizer).map_err(|m| JsError::new(&m.to_string()))?;
+        let model = T5EncoderModel::load(vb, &config)?;
+        Ok(Self { model, tokenizer })
+    }
+
+    pub fn decode(&mut self, input: JsValue) -> Result<JsValue, JsError> {
+        let device = &Device::Cpu;
+        let input: DecoderParams =
+            serde_wasm_bindgen::from_value(input).map_err(|m| JsError::new(&m.to_string()))?;
+
+        self.model.clear_kv_cache();
+        let sentences = input.sentences;
+        let normalize_embeddings = input.normalize_embeddings;
+        let n_sentences = sentences.len();
+        let mut all_embeddings = Vec::with_capacity(n_sentences);
+        for sentence in sentences {
+            let tokens = self
+                .tokenizer
+                .encode(sentence, true)
+                .map_err(|m| JsError::new(&m.to_string()))?
+                .get_ids()
+                .to_vec();
+            let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
+            let embeddings = self.model.forward(&token_ids)?;
+            console_log!("generated embeddings {:?}", embeddings.shape());
+            // Average pooling: take the mean embedding over all tokens (including padding).
+            let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?;
+            let embeddings = (embeddings.sum(1)? / (n_tokens as f64))?;
+            let embeddings = if normalize_embeddings {
+                embeddings.broadcast_div(&embeddings.sqr()?.sum_keepdim(1)?.sqrt()?)?
+            } else {
+                embeddings
+            };
+            console_log!("{:?}", embeddings.shape());
+            all_embeddings.push(embeddings.squeeze(0)?.to_vec1::<f32>()?);
+        }
+
+        Ok(serde_wasm_bindgen::to_value(&DecoderOutput {
+            embeddings: all_embeddings,
+        })?)
+    }
+}
+
+#[derive(serde::Serialize, serde::Deserialize)]
+struct ConditionalGenerationOutput {
+    generation: String,
+}
+
+#[derive(serde::Serialize, serde::Deserialize)]
+struct DecoderOutput {
+    embeddings: Vec<Vec<f32>>,
+}
+
+#[derive(serde::Serialize, serde::Deserialize)]
+pub struct DecoderParams {
+    sentences: Vec<String>,
+    normalize_embeddings: bool,
+}
+#[derive(serde::Serialize, serde::Deserialize)]
+pub struct ConditionalGenerationParams {
+    prompt: String,
+    temperature: f64,
+    seed: u64,
+    top_p: f64,
+    repeat_penalty: f32,
+    repeat_last_n: usize,
+    max_length: Option<usize>,
+}
+fn main() {
+    console_error_panic_hook::set_once();
+}
diff --git a/candle-wasm-examples/t5/src/lib.rs b/candle-wasm-examples/t5/src/lib.rs
new file mode 100644
index 00000000..cb15633c
--- /dev/null
+++ b/candle-wasm-examples/t5/src/lib.rs
@@ -0,0 +1,16 @@
+use wasm_bindgen::prelude::*;
+
+#[wasm_bindgen]
+extern "C" {
+    // Use `js_namespace` here to bind `console.log(..)` instead of just
+    // `log(..)`.
+    #[wasm_bindgen(js_namespace = console)]
+    pub fn log(s: &str);
+}
+
+#[macro_export]
+macro_rules! console_log {
+    // Note that this uses the `log` function imported in the `extern "C"`
+    // block above.
+    ($($t:tt)*) => ($crate::log(&format_args!($($t)*).to_string()))
+}
diff --git a/candle-wasm-examples/t5/utils.js b/candle-wasm-examples/t5/utils.js
new file mode 100644
index 00000000..e45e7d1b
--- /dev/null
+++ b/candle-wasm-examples/t5/utils.js
@@ -0,0 +1,168 @@
+export async function extractEmbeddings(
+ worker,
+ weightsURL,
+ tokenizerURL,
+ configURL,
+ modelID,
+ sentences,
+ updateStatus,
+ normalize_embeddings = true
+) {
+ return new Promise((resolve, reject) => {
+ worker.postMessage({
+ weightsURL,
+ tokenizerURL,
+ configURL,
+ modelID,
+ sentences,
+ normalize_embeddings,
+ });
+ function messageHandler(event) {
+ if ("error" in event.data) {
+ worker.removeEventListener("message", messageHandler);
+ reject(new Error(event.data.error));
+ }
+ if (event.data.status === "complete") {
+ worker.removeEventListener("message", messageHandler);
+ resolve(event.data);
+ }
+ if (updateStatus) updateStatus(event.data);
+ }
+ worker.addEventListener("message", messageHandler);
+ });
+}
+
+export async function generateText(
+ worker,
+ weightsURL,
+ tokenizerURL,
+ configURL,
+ modelID,
+ prompt,
+ params,
+ updateStatus
+) {
+ return new Promise((resolve, reject) => {
+ worker.postMessage({
+ weightsURL,
+ tokenizerURL,
+ configURL,
+ modelID,
+ prompt,
+ params,
+ });
+ function messageHandler(event) {
+ if ("error" in event.data) {
+ worker.removeEventListener("message", messageHandler);
+ reject(new Error(event.data.error));
+ }
+ if (event.data.status === "complete") {
+ worker.removeEventListener("message", messageHandler);
+ resolve(event.data);
+ }
+ if (updateStatus) updateStatus(event.data);
+ }
+ worker.addEventListener("message", messageHandler);
+ });
+}
+export const MODELS = {
+ t5_small_quantized: {
+ size: "102 MB",
+ base_url: "https://huggingface.co/lmz/candle-quantized-t5/resolve/main/",
+ model: "model.gguf",
+ tokenizer: "tokenizer.json",
+ config: "config.json",
+ tasks: {
+ translation_en_to_de: {
+ prefix: "translate English to German: ",
+ max_length: 300,
+ },
+ translation_en_to_fr: {
+ prefix: "translate English to French: ",
+ max_length: 300,
+ },
+ translation_en_to_ro: {
+ prefix: "translate English to Romanian: ",
+ max_length: 300,
+ },
+ summarization: { prefix: "summarize: ", max_length: 200 },
+ },
+ },
+ t5_small: {
+ size: "242 MB",
+ base_url: "https://huggingface.co/t5-small/resolve/main/",
+ model: "model.safetensors",
+ tokenizer: "tokenizer.json",
+ config: "config.json",
+ tasks: {
+ translation_en_to_de: {
+ prefix: "translate English to German: ",
+ max_length: 300,
+ },
+ translation_en_to_fr: {
+ prefix: "translate English to French: ",
+ max_length: 300,
+ },
+ translation_en_to_ro: {
+ prefix: "translate English to Romanian: ",
+ max_length: 300,
+ },
+ summarization: { prefix: "summarize: ", max_length: 200 },
+ },
+ },
+ flan_t5_small: {
+ size: "308 MB",
+ base_url:
+ "https://huggingface.co/google/flan-t5-small/resolve/refs%2Fpr%2F14/",
+ model: "model.safetensors",
+ tokenizer: "tokenizer.json",
+ config: "config.json",
+ tasks: {
+ translation_en_to_de: {
+ prefix: "translate English to German: ",
+ max_length: 300,
+ },
+ translation_en_to_fr: {
+ prefix: "translate English to French: ",
+ max_length: 300,
+ },
+ translation_en_to_ro: {
+ prefix: "translate English to Romanian: ",
+ max_length: 300,
+ },
+ summarization: { prefix: "summarize: ", max_length: 200 },
+ },
+ },
+
+ flan_t5_base_quantized: {
+ size: "360 MB",
+ base_url: "https://huggingface.co/lmz/candle-quantized-t5/resolve/main/",
+ model: "model-flan-t5-base.gguf",
+ tokenizer: "tokenizer.json",
+ config: "config-flan-t5-base.json",
+ tasks: {
+ translation_en_to_de: {
+ prefix: "translate English to German: ",
+ max_length: 300,
+ },
+ translation_en_to_fr: {
+ prefix: "translate English to French: ",
+ max_length: 300,
+ },
+ translation_en_to_ro: {
+ prefix: "translate English to Romanian: ",
+ max_length: 300,
+ },
+ summarization: { prefix: "summarize: ", max_length: 200 },
+ },
+ },
+};
+export function getModelInfo(id, taskID) {
+ const model = MODELS[id];
+ return {
+ modelURL: model.base_url + model.model,
+ configURL: model.base_url + model.config,
+ tokenizerURL: model.base_url + model.tokenizer,
+ maxLength: model.tasks[taskID].max_length,
+ };
+}
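`extractEmbeddings` above is exported but not exercised by `index.html`; a minimal sketch of driving the encoder worker with it (the model ID and sentences are arbitrary examples):

```js
import { getModelInfo, extractEmbeddings } from "./utils.js";

const worker = new Worker("./T5ModelEncoderWorker.js", { type: "module" });
const modelID = "t5_small";
// Any task works here; only the model/tokenizer/config URLs are used.
const { modelURL, tokenizerURL, configURL } = getModelInfo(
  modelID,
  "summarization"
);

extractEmbeddings(
  worker,
  modelURL,
  tokenizerURL,
  configURL,
  modelID,
  ["A sentence to embed.", "Another sentence."],
  (status) => console.log(status.message)
).then(({ output }) => console.log(output.embeddings));
```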