author    | Radamés Ajna <radamajna@gmail.com> | 2023-09-22 07:31:10 -0700
committer | GitHub <noreply@github.com>        | 2023-09-22 15:31:10 +0100
commit    | 19e52e5007e10816eefb2e1a1968be760c5d11a4 (patch)
tree      | f0d9cad35d261c3a28d2c4fa5a8ff1af84a4631f
parent    | 8601537e31af610c0bbd32ee8c8ee17ed802427c (diff)
T5 Wasm (#918)
* init t5 wasm model
* split workers for each model
* clean up
* add some ui
* readme
* index
* typo
* remove cache param, clear_kv_cache
* add max_length as param
* add model tasks option to ui
* add method to load quantized gguf from buffer
* Add quantized wasm module
* add quantized models to UI, dynamic import wasms
* link to quantized
* fix copy
* fix ModelEncoder
* fix README.md
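
The headline API addition here is `VarBuilder::from_gguf_buffer`, which lets the quantized wasm module build a model from a gguf checkpoint held in memory rather than on disk. As a rough sketch of how a page drives it through the generated bindings (model files taken from the demo's model list in `utils.js`; assumes an ES module context with top-level `await`):

```js
import init, { ModelConditionalGeneration } from "./build/m-quantized.js";

// Illustrative checkpoint files, per the MODELS table in utils.js.
const base = "https://huggingface.co/lmz/candle-quantized-t5/resolve/main/";
const fetchBytes = async (url) =>
  new Uint8Array(await (await fetch(url)).arrayBuffer());

await init(); // instantiate the wasm module
const [weights, tokenizer, config] = await Promise.all(
  ["model.gguf", "tokenizer.json", "config.json"].map((f) => fetchBytes(base + f))
);
// The constructor hands the raw gguf bytes to VarBuilder::from_gguf_buffer.
const model = new ModelConditionalGeneration(weights, tokenizer, config);
const { generation } = model.decode({
  prompt: "translate English to German: Hello",
  temperature: 0,
  seed: 299792458,
  top_p: 1,
  repeat_penalty: 1.1,
  repeat_last_n: 64,
});
console.log(generation);
```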
-rw-r--r-- | Cargo.toml                                              |    1
-rw-r--r-- | candle-transformers/src/models/quantized_t5.rs          |   15
-rw-r--r-- | candle-wasm-examples/t5/Cargo.toml                      |   33
-rw-r--r-- | candle-wasm-examples/t5/README.md                       |   32
-rw-r--r-- | candle-wasm-examples/t5/T5ModelConditionalGeneration.js |   93
-rw-r--r-- | candle-wasm-examples/t5/T5ModelEncoderWorker.js         |   83
-rw-r--r-- | candle-wasm-examples/t5/build-lib.sh                    |    3
-rw-r--r-- | candle-wasm-examples/t5/index.html                      |  276
-rw-r--r-- | candle-wasm-examples/t5/src/bin/m-quantized.rs          |  205
-rw-r--r-- | candle-wasm-examples/t5/src/bin/m.rs                    |  206
-rw-r--r-- | candle-wasm-examples/t5/src/lib.rs                      |   16
-rw-r--r-- | candle-wasm-examples/t5/utils.js                        |  168
12 files changed, 1131 insertions, 0 deletions
@@ -12,6 +12,7 @@ members = [
     "candle-wasm-examples/whisper",
     "candle-wasm-examples/yolo",
     "candle-wasm-examples/bert",
+    "candle-wasm-examples/t5",
 ]
 exclude = ["candle-flash-attn", "candle-kernels"]
 resolver = "2"
diff --git a/candle-transformers/src/models/quantized_t5.rs b/candle-transformers/src/models/quantized_t5.rs
index a10c3b80..a86dfcb3 100644
--- a/candle-transformers/src/models/quantized_t5.rs
+++ b/candle-transformers/src/models/quantized_t5.rs
@@ -30,6 +30,21 @@ impl VarBuilder {
         })
     }
 
+    pub fn from_gguf_buffer(buffer: &[u8]) -> Result<Self> {
+        let mut cursor = std::io::Cursor::new(buffer);
+        let content = candle::quantized::gguf_file::Content::read(&mut cursor)?;
+        let mut data = std::collections::HashMap::new();
+        for tensor_name in content.tensor_infos.keys() {
+            let tensor = content.tensor(&mut cursor, tensor_name)?;
+            data.insert(tensor_name.to_string(), Arc::new(tensor));
+        }
+        Ok(Self {
+            data: Arc::new(data),
+            path: Vec::new(),
+            device: Device::Cpu,
+        })
+    }
+
     fn pp<S: ToString>(&self, s: S) -> Self {
         let mut path = self.path.clone();
         path.push(s.to_string());
diff --git a/candle-wasm-examples/t5/Cargo.toml b/candle-wasm-examples/t5/Cargo.toml
new file mode 100644
index 00000000..b011d2e5
--- /dev/null
+++ b/candle-wasm-examples/t5/Cargo.toml
@@ -0,0 +1,33 @@
+[package]
+name = "candle-wasm-example-t5"
+version.workspace = true
+edition.workspace = true
+description.workspace = true
+repository.workspace = true
+keywords.workspace = true
+categories.workspace = true
+license.workspace = true
+
+[dependencies]
+candle = { path = "../../candle-core", version = "0.2.2", package = "candle-core" }
+candle-nn = { path = "../../candle-nn", version = "0.2.2" }
+candle-transformers = { path = "../../candle-transformers", version = "0.2.2" }
+num-traits = { workspace = true }
+tokenizers = { workspace = true, features = ["unstable_wasm"] }
+
+# App crates.
+anyhow = { workspace = true }
+byteorder = { workspace = true }
+log = { workspace = true }
+rand = { workspace = true }
+serde = { workspace = true }
+serde_json = { workspace = true }
+safetensors = { workspace = true }
+
+# Wasm specific crates.
+console_error_panic_hook = "0.1.7"
+getrandom = { version = "0.2", features = ["js"] }
+gloo = "0.8"
+js-sys = "0.3.64"
+wasm-bindgen = "0.2.87"
+serde-wasm-bindgen = "0.6.0"
diff --git a/candle-wasm-examples/t5/README.md b/candle-wasm-examples/t5/README.md
new file mode 100644
index 00000000..9a9f5bce
--- /dev/null
+++ b/candle-wasm-examples/t5/README.md
@@ -0,0 +1,32 @@
+## Running T5 with Candle and WASM
+
+Here, we provide two examples of how to run T5 using a Candle-compiled WASM binary and runtime.
+
+### Vanilla JS and WebWorkers
+
+To build and test the UI made in vanilla JS and WebWorkers, first build the WASM library:
+
+```bash
+sh build-lib.sh
+```
+
+This will bundle the library under `./build`, and we can import it inside our WebWorker like a normal JS module:
+
+```js
+import init, { ModelConditionalGeneration, ModelEncoder } from "./build/m.js";
+```
+
+For the quantized version, import the quantized module instead:
+
+```js
+import init, { ModelConditionalGeneration, ModelEncoder } from "./build/m-quantized.js";
+```
+
+The full example can be found under `./index.html`. All needed assets are fetched from the web, so there is nothing to download.
+Finally, you can preview the example by running a local HTTP server, for example:
+
+```bash
+python -m http.server
+```
+
+Then open `http://localhost:8000/index.html` in your browser.
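
The two worker scripts that follow wrap these wasm modules behind a `postMessage` protocol. A condensed sketch of the host-page side of that protocol (field names mirror the workers below; the prompt is illustrative):

```js
// Spawn the generation worker and wait for its "complete" message.
const base = "https://huggingface.co/lmz/candle-quantized-t5/resolve/main/";
const worker = new Worker("./T5ModelConditionalGeneration.js", { type: "module" });
worker.postMessage({
  weightsURL: base + "model.gguf",
  tokenizerURL: base + "tokenizer.json",
  configURL: base + "config.json",
  modelID: "t5_small_quantized", // "quantized" in the ID selects build/m-quantized.js
  prompt: "translate English to German: Hello",
  params: { temperature: 0, seed: 299792458, top_p: 1, repeat_penalty: 1.1, repeat_last_n: 64 },
});
worker.addEventListener("message", ({ data }) => {
  if ("error" in data) console.error(data.error);
  else if (data.status === "complete") console.log(data.output.generation);
  else console.log(data.status, data.message); // "loading", "decoding", ...
});
```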
diff --git a/candle-wasm-examples/t5/T5ModelConditionalGeneration.js b/candle-wasm-examples/t5/T5ModelConditionalGeneration.js
new file mode 100644
index 00000000..5f94c19a
--- /dev/null
+++ b/candle-wasm-examples/t5/T5ModelConditionalGeneration.js
@@ -0,0 +1,93 @@
+// Load the Candle T5 wasm module (plain or quantized, picked at runtime).
+let init, ModelConditionalGeneration;
+
+async function fetchArrayBuffer(url) {
+  const cacheName = "t5-candle-cache";
+  const cache = await caches.open(cacheName);
+  const cachedResponse = await cache.match(url);
+  if (cachedResponse) {
+    const data = await cachedResponse.arrayBuffer();
+    return new Uint8Array(data);
+  }
+  const res = await fetch(url, { cache: "force-cache" });
+  cache.put(url, res.clone());
+  return new Uint8Array(await res.arrayBuffer());
+}
+class ConditionalGeneration {
+  static instance = {};
+
+  static async getInstance(weightsURL, tokenizerURL, configURL, modelID) {
+    if (modelID.includes("quantized")) {
+      ({ default: init, ModelConditionalGeneration } = await import(
+        "./build/m-quantized.js"
+      ));
+    } else {
+      ({ default: init, ModelConditionalGeneration } = await import(
+        "./build/m.js"
+      ));
+    }
+    if (!this.instance[modelID]) {
+      await init();
+
+      self.postMessage({ status: "loading", message: "Loading Model" });
+      const [weightsArrayU8, tokenizerArrayU8, configArrayU8] =
+        await Promise.all([
+          fetchArrayBuffer(weightsURL),
+          fetchArrayBuffer(tokenizerURL),
+          fetchArrayBuffer(configURL),
+        ]);
+
+      this.instance[modelID] = new ModelConditionalGeneration(
+        weightsArrayU8,
+        tokenizerArrayU8,
+        configArrayU8
+      );
+    } else {
+      self.postMessage({ status: "ready", message: "Model Already Loaded" });
+    }
+    return this.instance[modelID];
+  }
+}
+
+self.addEventListener("message", async (event) => {
+  const { weightsURL, tokenizerURL, configURL, modelID, prompt, params } =
+    event.data;
+  let {
+    temperature = 0.0,
+    seed = 299792458,
+    repeat_penalty = 1.1,
+    repeat_last_n = 64,
+    top_p = 1,
+  } = { ...params };
+  try {
+    self.postMessage({
+      status: "ready",
+      message: "Starting T5 Conditional Generation",
+    });
+    const model = await ConditionalGeneration.getInstance(
+      weightsURL,
+      tokenizerURL,
+      configURL,
+      modelID
+    );
+    self.postMessage({
+      status: "decoding",
+      message: "Decoding Prompt",
+    });
+    const output = model.decode({
+      prompt,
+      temperature,
+      seed,
+      top_p,
+      repeat_penalty,
+      repeat_last_n,
+    });
+    self.postMessage({
+      status: "complete",
+      message: "complete",
+      output: output,
+    });
+  } catch (e) {
+    self.postMessage({ error: e });
+  }
+});
diff --git a/candle-wasm-examples/t5/T5ModelEncoderWorker.js b/candle-wasm-examples/t5/T5ModelEncoderWorker.js
new file mode 100644
index 00000000..a83b0ee0
--- /dev/null
+++ b/candle-wasm-examples/t5/T5ModelEncoderWorker.js
@@ -0,0 +1,83 @@
+// Load the Candle T5 wasm module (encoder side).
+let init, ModelEncoder;
+
+async function fetchArrayBuffer(url) {
+  const cacheName = "t5-candle-cache";
+  const cache = await caches.open(cacheName);
+  const cachedResponse = await cache.match(url);
+  if (cachedResponse) {
+    const data = await cachedResponse.arrayBuffer();
+    return new Uint8Array(data);
+  }
+  const res = await fetch(url, { cache: "force-cache" });
+  cache.put(url, res.clone());
+  return new Uint8Array(await res.arrayBuffer());
+}
+class Encoder {
+  static instance = {};
+
+  static async getInstance(weightsURL, tokenizerURL, configURL, modelID) {
+    if (modelID.includes("quantized")) {
+      ({ default: init, ModelEncoder } = await import(
+        "./build/m-quantized.js"
+      ));
+    } else {
+      ({ default: init, ModelEncoder } = await import("./build/m.js"));
+    }
+    if (!this.instance[modelID]) {
+      await init();
+
+      self.postMessage({ status: "loading", message: "Loading Model" });
+      const [weightsArrayU8, tokenizerArrayU8, configArrayU8] =
+        await Promise.all([
+          fetchArrayBuffer(weightsURL),
+          fetchArrayBuffer(tokenizerURL),
+          fetchArrayBuffer(configURL),
+        ]);
+
+      this.instance[modelID] = new ModelEncoder(
+        weightsArrayU8,
+        tokenizerArrayU8,
+        configArrayU8
+      );
+    } else {
+      self.postMessage({ status: "ready", message: "Model Already Loaded" });
+    }
+    return this.instance[modelID];
+  }
+}
+
+self.addEventListener("message", async (event) => {
+  const {
+    weightsURL,
+    tokenizerURL,
+    configURL,
+    modelID,
+    sentences,
+    normalize_embeddings,
+  } = event.data;
+  try {
+    self.postMessage({ status: "ready", message: "Starting T5 Encoder" });
+    const model = await Encoder.getInstance(
+      weightsURL,
+      tokenizerURL,
+      configURL,
+      modelID
+    );
+    self.postMessage({
+      status: "encoding",
+      message: "Encoding Sentences",
+    });
+    const output = model.decode({
+      sentences: sentences,
+      // Default to true only when the flag is absent ("|| true" would
+      // have forced normalization even when false was passed).
+      normalize_embeddings: normalize_embeddings ?? true,
+    });
+    self.postMessage({
+      status: "complete",
+      message: "complete",
+      output: output,
+    });
+  } catch (e) {
+    self.postMessage({ error: e });
+  }
+});
diff --git a/candle-wasm-examples/t5/build-lib.sh b/candle-wasm-examples/t5/build-lib.sh
new file mode 100644
index 00000000..a311f69c
--- /dev/null
+++ b/candle-wasm-examples/t5/build-lib.sh
@@ -0,0 +1,3 @@
+cargo build --target wasm32-unknown-unknown --release
+wasm-bindgen ../../target/wasm32-unknown-unknown/release/m.wasm --out-dir build --target web
+wasm-bindgen ../../target/wasm32-unknown-unknown/release/m-quantized.wasm --out-dir build --target web
diff --git a/candle-wasm-examples/t5/index.html b/candle-wasm-examples/t5/index.html
new file mode 100644
index 00000000..82b4e696
--- /dev/null
+++ b/candle-wasm-examples/t5/index.html
@@ -0,0 +1,276 @@
+<html>
+  <head>
+    <meta content="text/html;charset=utf-8" http-equiv="Content-Type" />
+    <title>Candle T5</title>
+  </head>
+
+  <body></body>
+</html>
+
+<!DOCTYPE html>
+<html>
+  <head>
+    <meta charset="UTF-8" />
+    <meta name="viewport" content="width=device-width, initial-scale=1.0" />
+    <style>
+      @import url("https://fonts.googleapis.com/css2?family=Source+Code+Pro:wght@200;300;400&family=Source+Sans+3:wght@100;200;300;400;500;600;700;800;900&display=swap");
+
+      html,
+      body {
+        font-family: "Source Sans 3", sans-serif;
+      }
+    </style>
+    <style type="text/tailwindcss">
+      .link {
+        @apply underline hover:text-blue-500 hover:no-underline;
+      }
+    </style>
+    <script src="https://cdn.tailwindcss.com"></script>
+    <script type="module">
+      import {
+        getModelInfo,
+        MODELS,
+        extractEmbeddings,
+        generateText,
+      } from "./utils.js";
+
+      const t5ModelEncoderWorker = new Worker("./T5ModelEncoderWorker.js", {
+        type: "module",
+      });
+      const t5ModelConditionalGeneration = new Worker(
+        "./T5ModelConditionalGeneration.js",
+        { type: "module" }
+      );
+
+      const formEl = document.querySelector("#form");
+      const modelEl = document.querySelector("#model");
+      const promptEl = document.querySelector("#prompt");
+      const temperatureEl = document.querySelector("#temperature");
+      const toppEL = document.querySelector("#top-p");
+      const repeatPenaltyEl = document.querySelector("#repeat_penalty");
+      const seedEl = document.querySelector("#seed");
+      const outputEl = document.querySelector("#output-generation");
+      const tasksEl = document.querySelector("#tasks");
+      let selectedTaskID = "";
+
+      document.addEventListener("DOMContentLoaded", () => {
+        for (const [id, model] of Object.entries(MODELS)) {
+          const option = document.createElement("option");
+          option.value = id;
+          option.innerText = `${id} (${model.size})`;
+          modelEl.appendChild(option);
+        }
+        populateTasks(modelEl.value);
+        modelEl.addEventListener("change", (e) => {
+          populateTasks(e.target.value);
+        });
+        tasksEl.addEventListener("change", (e) => {
+          const task = e.target.value;
+          const modelID = modelEl.value;
+          promptEl.value = MODELS[modelID].tasks[task].prefix;
+          selectedTaskID = task;
+        });
+      });
+      function populateTasks(modelID) {
+        const tasks = MODELS[modelID].tasks;
+        tasksEl.innerHTML = "";
+        for (const [task, params] of Object.entries(tasks)) {
+          const div = document.createElement("div");
+          div.innerHTML = `
+            <input
+              type="radio"
+              name="task"
+              id="${task}"
+              class="font-light cursor-pointer"
+              value="${task}" />
+            <label for="${task}" class="cursor-pointer">
+              ${params.prefix}
+            </label>
+          `;
+          tasksEl.appendChild(div);
+        }
+        selectedTaskID = Object.keys(tasks)[0];
+        tasksEl.querySelector(`#${selectedTaskID}`).checked = true;
+      }
+      formEl.addEventListener("submit", (e) => {
+        e.preventDefault();
+
+        const promptText = promptEl.value;
+        const modelID = modelEl.value;
+        const { modelURL, configURL, tokenizerURL, maxLength } = getModelInfo(
+          modelID,
+          selectedTaskID
+        );
+        const params = {
+          temperature: Number(temperatureEl.value),
+          top_p: Number(toppEL.value),
+          repeat_penalty: Number(repeatPenaltyEl.value),
+          seed: BigInt(seedEl.value),
+          max_length: maxLength,
+        };
+        generateText(
+          t5ModelConditionalGeneration,
+          modelURL,
+          tokenizerURL,
+          configURL,
+          modelID,
+          promptText,
+          params,
+          (status) => {
+            if (status.status === "loading") {
+              outputEl.innerText = "Loading model...";
+            }
+            if (status.status === "decoding") {
+              outputEl.innerText = "Generating...";
+            }
+          }
+        ).then(({ output }) => {
+          outputEl.innerText = output.generation;
+        });
+      });
+    </script>
+  </head>
+
+  <body class="container max-w-4xl mx-auto p-4">
+    <main class="grid grid-cols-1 gap-8 relative">
+      <span class="absolute text-5xl -ml-[1em]"> 🕯️ </span>
+      <div>
+        <h1 class="text-5xl font-bold">Candle T5 Transformer</h1>
+        <h2 class="text-2xl font-bold">Rust/WASM Demo</h2>
+        <p class="max-w-lg">
+          This demo showcases Text-To-Text Transfer Transformer (<a
+            href="https://blog.research.google/2020/02/exploring-transfer-learning-with-t5.html"
+            target="_blank"
+            class="link"
+            >T5</a
+          >) models right in your browser, thanks to the
+          <a
+            href="https://github.com/huggingface/candle/"
+            target="_blank"
+            class="link">
+            Candle
+          </a>
+          ML framework and Rust/WASM. You can choose from a range of available
+          models, including
+          <a
+            href="https://huggingface.co/t5-small"
+            target="_blank"
+            class="link">
+            t5-small</a
+          >,
+          <a href="https://huggingface.co/t5-base" target="_blank" class="link"
+            >t5-base</a
+          >,
+          <a
+            href="https://huggingface.co/google/flan-t5-small"
+            target="_blank"
+            class="link"
+            >flan-t5-small</a
+          >
+          and several
+          <a
+            href="https://huggingface.co/lmz/candle-quantized-t5/tree/main"
+            target="_blank"
+            class="link">
+            quantized T5 gguf checkpoints</a
+          >.
+        </p>
+      </div>
+
+      <div>
+        <label for="model" class="font-medium">Model Options: </label>
+        <select
+          id="model"
+          class="border-2 border-gray-500 rounded-md font-light"></select>
+      </div>
+
+      <div>
+        <h3 class="font-medium">Task Prefix:</h3>
+        <form id="tasks" class="flex flex-col gap-1 my-2"></form>
+      </div>
+      <form
+        id="form"
+        class="flex text-normal px-1 py-1 border border-gray-700 rounded-md items-center">
+        <input type="submit" hidden />
+        <input
+          type="text"
+          id="prompt"
+          class="font-light w-full px-3 py-2 mx-1 resize-none outline-none"
+          placeholder="Add prompt here, e.g. 'translate English to German: Today I'm going to eat Ice Cream'"
+          value="translate English to German: Today I'm going to eat Ice Cream" />
+        <button
+          class="bg-gray-700 hover:bg-gray-800 text-white font-normal py-2 w-16 rounded disabled:bg-gray-300 disabled:cursor-not-allowed">
+          Run
+        </button>
+      </form>
+      <div class="grid grid-cols-3 max-w-md items-center gap-3">
+        <label class="text-sm font-medium" for="temperature">Temperature</label>
+        <input
+          type="range"
+          id="temperature"
+          name="temperature"
+          min="0"
+          max="2"
+          step="0.01"
+          value="0.00"
+          oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)" />
+        <output
+          class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md">
+          0.00</output
+        >
+        <label class="text-sm font-medium" for="top-p">Top-p</label>
+        <input
+          type="range"
+          id="top-p"
+          name="top-p"
+          min="0"
+          max="1"
+          step="0.01"
+          value="1.00"
+          oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)" />
+        <output
+          class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md">
+          1.00</output
+        >
+
+        <label class="text-sm font-medium" for="repeat_penalty"
+          >Repeat Penalty</label
+        >
+
+        <input
+          type="range"
+          id="repeat_penalty"
+          name="repeat_penalty"
+          min="-2"
+          max="2"
+          step="0.01"
+          value="1.10"
+          oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)" />
+        <output
+          class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md"
+          >1.10</output
+        >
+        <label class="text-sm font-medium" for="seed">Seed</label>
+        <input
+          type="number"
+          id="seed"
+          name="seed"
+          value="299792458"
+          class="font-light border border-gray-700 text-right rounded-md p-2" />
+        <button
+          id="run"
+          onclick="document.querySelector('#seed').value = BigInt(Math.floor(Math.random() * 2**64-1))"
+          class="bg-gray-700 hover:bg-gray-800 text-white font-normal py-1 w-[50px] rounded disabled:bg-gray-300 disabled:cursor-not-allowed text-sm">
+          Rand
+        </button>
+      </div>
+      <div>
+        <h3 class="font-medium">Generation:</h3>
+        <div
+          class="min-h-[250px] bg-slate-100 text-gray-500 p-4 rounded-md flex flex-col gap-2 text-lg">
+          <p id="output-generation" class="grid-rows-2">No output yet</p>
+        </div>
+      </div>
+    </main>
+  </body>
+</html>
diff --git a/candle-wasm-examples/t5/src/bin/m-quantized.rs b/candle-wasm-examples/t5/src/bin/m-quantized.rs
new file mode 100644
index 00000000..2f490b84
--- /dev/null
+++ b/candle-wasm-examples/t5/src/bin/m-quantized.rs
@@ -0,0 +1,205 @@
+use candle::{Device, Tensor};
+use candle_transformers::generation::LogitsProcessor;
+pub use candle_transformers::models::quantized_t5::{
+    Config, T5EncoderModel, T5ForConditionalGeneration, VarBuilder,
+};
+
+use candle_wasm_example_t5::console_log;
+use tokenizers::Tokenizer;
+use wasm_bindgen::prelude::*;
+
+#[wasm_bindgen]
+pub struct ModelEncoder {
+    model: T5EncoderModel,
+    tokenizer: Tokenizer,
+}
+#[wasm_bindgen]
+
+pub struct ModelConditionalGeneration {
+    model: T5ForConditionalGeneration,
+    tokenizer: Tokenizer,
+    config: Config,
+}
+
+#[wasm_bindgen]
+impl ModelConditionalGeneration {
+    #[wasm_bindgen(constructor)]
+    pub fn load(
+        weights: Vec<u8>,
+        tokenizer: Vec<u8>,
+        config: Vec<u8>,
+    ) -> Result<ModelConditionalGeneration, JsError> {
+        console_error_panic_hook::set_once();
+        console_log!("loading model");
+        let vb = VarBuilder::from_gguf_buffer(&weights)?;
+        let mut config: Config = serde_json::from_slice(&config)?;
+        let tokenizer =
+            Tokenizer::from_bytes(&tokenizer).map_err(|m| JsError::new(&m.to_string()))?;
+        let model = T5ForConditionalGeneration::load(vb, &config)?;
+        config.use_cache = false;
+        Ok(Self {
+            model,
+            tokenizer,
+            config,
+        })
+    }
+    pub fn decode(&mut self, input: JsValue) -> Result<JsValue, JsError> {
+        let input: ConditionalGenerationParams =
+            serde_wasm_bindgen::from_value(input).map_err(|m| JsError::new(&m.to_string()))?;
+        let device = &Device::Cpu;
+        self.model.clear_kv_cache();
+        let mut output_token_ids = [self.config.pad_token_id as u32].to_vec();
+        let prompt = input.prompt;
+        let repeat_penalty = input.repeat_penalty;
+        let repeat_last_n = input.repeat_last_n;
+        let seed = input.seed;
+        let max_length = usize::clamp(input.max_length.unwrap_or(512), 0, 512);
+        let temperature = if input.temperature <= 0. {
+            None
+        } else {
+            Some(input.temperature)
+        };
+        let top_p = if input.top_p <= 0. || input.top_p >= 1. {
+            None
+        } else {
+            Some(input.top_p)
+        };
+        let mut logits_processor = LogitsProcessor::new(seed, temperature, top_p);
+        let tokens = self
+            .tokenizer
+            .encode(prompt, true)
+            .map_err(|m| JsError::new(&m.to_string()))?
+            .get_ids()
+            .to_vec();
+
+        let input_token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
+        let encoder_output = self.model.encode(&input_token_ids)?;
+        let mut decoded = String::new();
+        for index in 0.. {
+            if output_token_ids.len() > max_length {
+                break;
+            }
+            let decoder_token_ids = if index == 0 {
+                Tensor::new(output_token_ids.as_slice(), device)?.unsqueeze(0)?
+            } else {
+                let last_token = *output_token_ids.last().unwrap();
+                Tensor::new(&[last_token], device)?.unsqueeze(0)?
+            };
+            let logits = self
+                .model
+                .decode(&decoder_token_ids, &encoder_output)?
+                .squeeze(0)?;
+            let logits = if repeat_penalty == 1. {
+                logits
+            } else {
+                let start_at = output_token_ids.len().saturating_sub(repeat_last_n);
+                candle_transformers::utils::apply_repeat_penalty(
+                    &logits,
+                    repeat_penalty,
+                    &output_token_ids[start_at..],
+                )?
+            };
+
+            let next_token_id = logits_processor.sample(&logits)?;
+            if next_token_id as usize == self.config.eos_token_id {
+                break;
+            }
+            output_token_ids.push(next_token_id);
+            if let Some(text) = self.tokenizer.id_to_token(next_token_id) {
+                let text = text.replace('▁', " ").replace("<0x0A>", "\n");
+                decoded += &text;
+            }
+        }
+        Ok(serde_wasm_bindgen::to_value(
+            &ConditionalGenerationOutput {
+                generation: decoded,
+            },
+        )?)
+    }
+}
+
+#[wasm_bindgen]
+impl ModelEncoder {
+    #[wasm_bindgen(constructor)]
+    pub fn load(
+        weights: Vec<u8>,
+        tokenizer: Vec<u8>,
+        config: Vec<u8>,
+    ) -> Result<ModelEncoder, JsError> {
+        console_error_panic_hook::set_once();
+        console_log!("loading model");
+        let vb = VarBuilder::from_gguf_buffer(&weights)?;
+        let mut config: Config = serde_json::from_slice(&config)?;
+        config.use_cache = false;
+        let tokenizer =
+            Tokenizer::from_bytes(&tokenizer).map_err(|m| JsError::new(&m.to_string()))?;
+        let model = T5EncoderModel::load(vb, &config)?;
+        Ok(Self { model, tokenizer })
+    }
+
+    pub fn decode(&mut self, input: JsValue) -> Result<JsValue, JsError> {
+        let device = &Device::Cpu;
+        let input: DecoderParams =
+            serde_wasm_bindgen::from_value(input).map_err(|m| JsError::new(&m.to_string()))?;
+
+        self.model.clear_kv_cache();
+        let sentences = input.sentences;
+        let normalize_embeddings = input.normalize_embeddings;
+        let n_sentences = sentences.len();
+        let mut all_embeddings = Vec::with_capacity(n_sentences);
+        for sentence in sentences {
+            let tokens = self
+                .tokenizer
+                .encode(sentence, true)
+                .map_err(|m| JsError::new(&m.to_string()))?
+                .get_ids()
+                .to_vec();
+            let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
+            let embeddings = self.model.forward(&token_ids)?;
+            console_log!("generated embeddings {:?}", embeddings.shape());
+            // Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
+            let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?;
+            let embeddings = (embeddings.sum(1)? / (n_tokens as f64))?;
+            let embeddings = if normalize_embeddings {
+                embeddings.broadcast_div(&embeddings.sqr()?.sum_keepdim(1)?.sqrt()?)?
+            } else {
+                embeddings
+            };
+            console_log!("{:?}", embeddings.shape());
+            all_embeddings.push(embeddings.squeeze(0)?.to_vec1::<f32>()?);
+        }
+
+        Ok(serde_wasm_bindgen::to_value(&DecoderOutput {
+            embeddings: all_embeddings,
+        })?)
+    }
+}
+
+#[derive(serde::Serialize, serde::Deserialize)]
+struct ConditionalGenerationOutput {
+    generation: String,
+}
+
+#[derive(serde::Serialize, serde::Deserialize)]
+struct DecoderOutput {
+    embeddings: Vec<Vec<f32>>,
+}
+
+#[derive(serde::Serialize, serde::Deserialize)]
+pub struct DecoderParams {
+    sentences: Vec<String>,
+    normalize_embeddings: bool,
+}
+#[derive(serde::Serialize, serde::Deserialize)]
+pub struct ConditionalGenerationParams {
+    prompt: String,
+    temperature: f64,
+    seed: u64,
+    top_p: f64,
+    repeat_penalty: f32,
+    repeat_last_n: usize,
+    max_length: Option<usize>,
+}
+fn main() {
+    console_error_panic_hook::set_once();
+}
diff --git a/candle-wasm-examples/t5/src/bin/m.rs b/candle-wasm-examples/t5/src/bin/m.rs
new file mode 100644
index 00000000..c82e00cd
--- /dev/null
+++ b/candle-wasm-examples/t5/src/bin/m.rs
@@ -0,0 +1,206 @@
+use candle::{DType, Device, Tensor};
+use candle_nn::VarBuilder;
+use candle_transformers::generation::LogitsProcessor;
+pub use candle_transformers::models::t5::{Config, T5EncoderModel, T5ForConditionalGeneration};
+use candle_wasm_example_t5::console_log;
+use tokenizers::Tokenizer;
+use wasm_bindgen::prelude::*;
+#[wasm_bindgen]
+pub struct ModelEncoder {
+    model: T5EncoderModel,
+    tokenizer: Tokenizer,
+}
+#[wasm_bindgen]
+
+pub struct ModelConditionalGeneration {
+    model: T5ForConditionalGeneration,
+    tokenizer: Tokenizer,
+    config: Config,
+}
+
+#[wasm_bindgen]
+impl ModelConditionalGeneration {
+    #[wasm_bindgen(constructor)]
+    pub fn load(
+        weights: Vec<u8>,
+        tokenizer: Vec<u8>,
+        config: Vec<u8>,
+    ) -> Result<ModelConditionalGeneration, JsError> {
+        console_error_panic_hook::set_once();
+        console_log!("loading model");
+        let device = &Device::Cpu;
+        let weights = safetensors::tensor::SafeTensors::deserialize(&weights)?;
+        let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, device);
+        let mut config: Config = serde_json::from_slice(&config)?;
+        let tokenizer =
+            Tokenizer::from_bytes(&tokenizer).map_err(|m| JsError::new(&m.to_string()))?;
+        let model = T5ForConditionalGeneration::load(vb, &config)?;
+        config.use_cache = false;
+        Ok(Self {
+            model,
+            tokenizer,
+            config,
+        })
+    }
+    pub fn decode(&mut self, input: JsValue) -> Result<JsValue, JsError> {
+        let input: ConditionalGenerationParams =
+            serde_wasm_bindgen::from_value(input).map_err(|m| JsError::new(&m.to_string()))?;
+        let device = &Device::Cpu;
+        self.model.clear_kv_cache();
+        let mut output_token_ids = [self.config.pad_token_id as u32].to_vec();
+        let prompt = input.prompt;
+        let repeat_penalty = input.repeat_penalty;
+        let repeat_last_n = input.repeat_last_n;
+        let seed = input.seed;
+        let max_length = usize::clamp(input.max_length.unwrap_or(512), 0, 512);
+        let temperature = if input.temperature <= 0. {
+            None
+        } else {
+            Some(input.temperature)
+        };
+        let top_p = if input.top_p <= 0. || input.top_p >= 1. {
+            None
+        } else {
+            Some(input.top_p)
+        };
+        let mut logits_processor = LogitsProcessor::new(seed, temperature, top_p);
+        let tokens = self
+            .tokenizer
+            .encode(prompt, true)
+            .map_err(|m| JsError::new(&m.to_string()))?
+            .get_ids()
+            .to_vec();
+
+        let input_token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
+        let encoder_output = self.model.encode(&input_token_ids)?;
+        let mut decoded = String::new();
+        for index in 0.. {
+            if output_token_ids.len() > max_length {
+                break;
+            }
+            let decoder_token_ids = if index == 0 {
+                Tensor::new(output_token_ids.as_slice(), device)?.unsqueeze(0)?
+            } else {
+                let last_token = *output_token_ids.last().unwrap();
+                Tensor::new(&[last_token], device)?.unsqueeze(0)?
+            };
+            let logits = self
+                .model
+                .decode(&decoder_token_ids, &encoder_output)?
+                .squeeze(0)?;
+            let logits = if repeat_penalty == 1. {
+                logits
+            } else {
+                let start_at = output_token_ids.len().saturating_sub(repeat_last_n);
+                candle_transformers::utils::apply_repeat_penalty(
+                    &logits,
+                    repeat_penalty,
+                    &output_token_ids[start_at..],
+                )?
+            };
+
+            let next_token_id = logits_processor.sample(&logits)?;
+            if next_token_id as usize == self.config.eos_token_id {
+                break;
+            }
+            output_token_ids.push(next_token_id);
+            if let Some(text) = self.tokenizer.id_to_token(next_token_id) {
+                let text = text.replace('▁', " ").replace("<0x0A>", "\n");
+                decoded += &text;
+            }
+        }
+        Ok(serde_wasm_bindgen::to_value(
+            &ConditionalGenerationOutput {
+                generation: decoded,
+            },
+        )?)
+    }
+}
+
+#[wasm_bindgen]
+impl ModelEncoder {
+    #[wasm_bindgen(constructor)]
+    pub fn load(
+        weights: Vec<u8>,
+        tokenizer: Vec<u8>,
+        config: Vec<u8>,
+    ) -> Result<ModelEncoder, JsError> {
+        console_error_panic_hook::set_once();
+        console_log!("loading model");
+        let device = &Device::Cpu;
+        let weights = safetensors::tensor::SafeTensors::deserialize(&weights)?;
+        let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, device);
+        let mut config: Config = serde_json::from_slice(&config)?;
+        config.use_cache = false;
+        let tokenizer =
+            Tokenizer::from_bytes(&tokenizer).map_err(|m| JsError::new(&m.to_string()))?;
+        let model = T5EncoderModel::load(vb, &config)?;
+        Ok(Self { model, tokenizer })
+    }
+
+    pub fn decode(&mut self, input: JsValue) -> Result<JsValue, JsError> {
+        let device = &Device::Cpu;
+        let input: DecoderParams =
+            serde_wasm_bindgen::from_value(input).map_err(|m| JsError::new(&m.to_string()))?;
+
+        self.model.clear_kv_cache();
+        let sentences = input.sentences;
+        let normalize_embeddings = input.normalize_embeddings;
+        let n_sentences = sentences.len();
+        let mut all_embeddings = Vec::with_capacity(n_sentences);
+        for sentence in sentences {
+            let tokens = self
+                .tokenizer
+                .encode(sentence, true)
+                .map_err(|m| JsError::new(&m.to_string()))?
+                .get_ids()
+                .to_vec();
+            let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
+            let embeddings = self.model.forward(&token_ids)?;
+            console_log!("generated embeddings {:?}", embeddings.shape());
+            // Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
+            let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?;
+            let embeddings = (embeddings.sum(1)? / (n_tokens as f64))?;
+            let embeddings = if normalize_embeddings {
+                embeddings.broadcast_div(&embeddings.sqr()?.sum_keepdim(1)?.sqrt()?)?
+            } else {
+                embeddings
+            };
+            console_log!("{:?}", embeddings.shape());
+            all_embeddings.push(embeddings.squeeze(0)?.to_vec1::<f32>()?);
+        }
+
+        Ok(serde_wasm_bindgen::to_value(&DecoderOutput {
+            embeddings: all_embeddings,
+        })?)
+    }
+}
+
+#[derive(serde::Serialize, serde::Deserialize)]
+struct ConditionalGenerationOutput {
+    generation: String,
+}
+
+#[derive(serde::Serialize, serde::Deserialize)]
+struct DecoderOutput {
+    embeddings: Vec<Vec<f32>>,
+}
+
+#[derive(serde::Serialize, serde::Deserialize)]
+pub struct DecoderParams {
+    sentences: Vec<String>,
+    normalize_embeddings: bool,
+}
+#[derive(serde::Serialize, serde::Deserialize)]
+pub struct ConditionalGenerationParams {
+    prompt: String,
+    temperature: f64,
+    seed: u64,
+    top_p: f64,
+    repeat_penalty: f32,
+    repeat_last_n: usize,
+    max_length: Option<usize>,
+}
+fn main() {
+    console_error_panic_hook::set_once();
+}
diff --git a/candle-wasm-examples/t5/src/lib.rs b/candle-wasm-examples/t5/src/lib.rs
new file mode 100644
index 00000000..cb15633c
--- /dev/null
+++ b/candle-wasm-examples/t5/src/lib.rs
@@ -0,0 +1,16 @@
+use wasm_bindgen::prelude::*;
+
+#[wasm_bindgen]
+extern "C" {
+    // Use `js_namespace` here to bind `console.log(..)` instead of just
+    // `log(..)`
+    #[wasm_bindgen(js_namespace = console)]
+    pub fn log(s: &str);
+}
+
+#[macro_export]
+macro_rules! console_log {
+    // Note that this is using the `log` function imported above during
+    // `bare_bones`
+    ($($t:tt)*) => ($crate::log(&format_args!($($t)*).to_string()))
+}
diff --git a/candle-wasm-examples/t5/utils.js b/candle-wasm-examples/t5/utils.js
new file mode 100644
index 00000000..e45e7d1b
--- /dev/null
+++ b/candle-wasm-examples/t5/utils.js
@@ -0,0 +1,168 @@
+export async function extractEmbeddings(
+  worker,
+  weightsURL,
+  tokenizerURL,
+  configURL,
+  modelID,
+  sentences,
+  updateStatus,
+  normalize_embeddings = true
+) {
+  return new Promise((resolve, reject) => {
+    worker.postMessage({
+      weightsURL,
+      tokenizerURL,
+      configURL,
+      modelID,
+      sentences,
+      normalize_embeddings,
+    });
+    function messageHandler(event) {
+      if ("error" in event.data) {
+        worker.removeEventListener("message", messageHandler);
+        reject(new Error(event.data.error));
+      }
+      if (event.data.status === "complete") {
+        worker.removeEventListener("message", messageHandler);
+        resolve(event.data);
+      }
+      if (updateStatus) updateStatus(event.data);
+    }
+    worker.addEventListener("message", messageHandler);
+  });
+}
+
+export async function generateText(
+  worker,
+  weightsURL,
+  tokenizerURL,
+  configURL,
+  modelID,
+  prompt,
+  params,
+  updateStatus
+) {
+  return new Promise((resolve, reject) => {
+    worker.postMessage({
+      weightsURL,
+      tokenizerURL,
+      configURL,
+      modelID,
+      prompt,
+      params,
+    });
+    function messageHandler(event) {
+      if ("error" in event.data) {
+        worker.removeEventListener("message", messageHandler);
+        reject(new Error(event.data.error));
+      }
+      if (event.data.status === "complete") {
+        worker.removeEventListener("message", messageHandler);
+        resolve(event.data);
+      }
+      if (updateStatus) updateStatus(event.data);
+    }
+    worker.addEventListener("message", messageHandler);
+  });
+}
+export const MODELS = {
+  t5_small_quantized: {
+    size: "102 MB",
+    base_url: "https://huggingface.co/lmz/candle-quantized-t5/resolve/main/",
+    model: "model.gguf",
+    tokenizer: "tokenizer.json",
+    config: "config.json",
+    tasks: {
+      translation_en_to_de: {
+        prefix: "translate English to German: ",
+        max_length: 300,
+      },
+      translation_en_to_fr: {
+        prefix: "translate English to French: ",
+        max_length: 300,
+      },
+      translation_en_to_ro: {
+        prefix: "translate English to Romanian: ",
+        max_length: 300,
+      },
+      summarization: { prefix: "summarize: ", max_length: 200 },
+    },
+  },
+  t5_small: {
+    size: "242 MB",
+    base_url: "https://huggingface.co/t5-small/resolve/main/",
+    model: "model.safetensors",
+    tokenizer: "tokenizer.json",
+    config: "config.json",
+    tasks: {
+      translation_en_to_de: {
+        prefix: "translate English to German: ",
+        max_length: 300,
+      },
+      translation_en_to_fr: {
+        prefix: "translate English to French: ",
+        max_length: 300,
+      },
+      translation_en_to_ro: {
+        prefix: "translate English to Romanian: ",
+        max_length: 300,
+      },
+      summarization: { prefix: "summarize: ", max_length: 200 },
+    },
+  },
+  flan_t5_small: {
+    size: "308 MB",
+    base_url:
+      "https://huggingface.co/google/flan-t5-small/resolve/refs%2Fpr%2F14/",
+    model: "model.safetensors",
+    tokenizer: "tokenizer.json",
+    config: "config.json",
+    tasks: {
+      translation_en_to_de: {
+        prefix: "translate English to German: ",
+        max_length: 300,
+      },
+      translation_en_to_fr: {
+        prefix: "translate English to French: ",
+        max_length: 300,
+      },
+      translation_en_to_ro: {
+        prefix: "translate English to Romanian: ",
+        max_length: 300,
+      },
+      summarization: { prefix: "summarize: ", max_length: 200 },
+    },
+  },
+
+  flan_t5_base_quantized: {
+    size: "360 MB",
+    base_url: "https://huggingface.co/lmz/candle-quantized-t5/resolve/main/",
+    model: "model-flan-t5-base.gguf",
+    tokenizer: "tokenizer.json",
+    config: "config-flan-t5-base.json",
+    tasks: {
+      translation_en_to_de: {
+        prefix: "translate English to German: ",
+        max_length: 300,
+      },
+      translation_en_to_fr: {
+        prefix: "translate English to French: ",
+        max_length: 300,
+      },
+      translation_en_to_ro: {
+        prefix: "translate English to Romanian: ",
+        max_length: 300,
+      },
+      summarization: { prefix: "summarize: ", max_length: 200 },
+    },
+  },
+};
+export function getModelInfo(id, taskID) {
+  const model = MODELS[id];
+  return {
+    modelURL: model.base_url + model.model,
+    configURL: model.base_url + model.config,
+    tokenizerURL: model.base_url + model.tokenizer,
+    maxLength: model.tasks[taskID].max_length,
+  };
+}