author | Radamés Ajna <radamajna@gmail.com> | 2023-09-04 07:59:22 -0700
---|---|---
committer | GitHub <noreply@github.com> | 2023-09-04 15:59:22 +0100
commit | 8395152d20c7c72fb866ca3f8cbcab8859bfed57 (patch) |
tree | 73d573ed4b89bcfda0d3c01fa37fa3a6b0f39fb1 /candle-wasm-examples/llama2-c |
parent | e2f9f60ac2e4ab6b39b8275442f7ad4e76995707 (diff) |
download | candle-8395152d20c7c72fb866ca3f8cbcab8859bfed57.tar.gz, candle-8395152d20c7c72fb866ca3f8cbcab8859bfed57.tar.bz2, candle-8395152d20c7c72fb866ca3f8cbcab8859bfed57.zip |
Llama2c WASM UI improvements (#732)
* pass seed, expose model seq_len
* wip new llama2.c ui
* final new UI example
* small copy
* copy
Diffstat (limited to 'candle-wasm-examples/llama2-c')
-rw-r--r-- | candle-wasm-examples/llama2-c/README.md | 47
-rw-r--r-- | candle-wasm-examples/llama2-c/build-lib.sh | 2
-rw-r--r-- | candle-wasm-examples/llama2-c/lib-example.html | 311
-rw-r--r-- | candle-wasm-examples/llama2-c/llama2cWorker.js | 96
-rw-r--r-- | candle-wasm-examples/llama2-c/src/bin/m.rs | 8
-rw-r--r-- | candle-wasm-examples/llama2-c/src/worker.rs | 2
6 files changed, 464 insertions, 2 deletions
````diff
diff --git a/candle-wasm-examples/llama2-c/README.md b/candle-wasm-examples/llama2-c/README.md
new file mode 100644
index 00000000..0b41e064
--- /dev/null
+++ b/candle-wasm-examples/llama2-c/README.md
@@ -0,0 +1,47 @@
+## Running [llama2.c](https://github.com/karpathy/llama2.c) Examples
+
+Here, we provide two examples of how to run [llama2.c](https://github.com/karpathy/llama2.c), written in Rust, using a Candle-compiled WASM binary and runtimes.
+
+### Pure Rust UI
+
+To build and test the UI made in Rust, you will need [Trunk](https://trunkrs.dev/#install).
+From the `candle-wasm-examples/llama2-c` directory, run the following.
+
+Download the assets:
+
+```bash
+# Model and tokenizer
+
+wget -c https://huggingface.co/spaces/lmz/candle-llama2/resolve/main/model.bin
+wget -c https://huggingface.co/spaces/lmz/candle-llama2/resolve/main/tokenizer.json
+
+```
+
+Run the hot-reload server:
+
+```bash
+trunk serve --release --public-url / --port 8080
+```
+
+### Vanilla JS and WebWorkers
+
+To build and test the UI made in Vanilla JS and WebWorkers, first build the WASM library:
+
+```bash
+sh build-lib.sh
+```
+
+This will bundle the library under `./build`, and we can then import it inside our WebWorker like a normal JS module:
+
+```js
+import init, { Model } from "./build/m.js";
+```
+
+The full example can be found under `./lib-example.html`. All the required assets are fetched from the web, so there is no need to download anything.
+Finally, you can preview the example by running a local HTTP server. For example:
+
+```bash
+python -m http.server
+```
+
+Then open `http://localhost:8000/lib-example.html` in your browser.
````

````diff
diff --git a/candle-wasm-examples/llama2-c/build-lib.sh b/candle-wasm-examples/llama2-c/build-lib.sh
new file mode 100644
index 00000000..b0ebb182
--- /dev/null
+++ b/candle-wasm-examples/llama2-c/build-lib.sh
@@ -0,0 +1,2 @@
+cargo build --target wasm32-unknown-unknown --release
+wasm-bindgen ../../target/wasm32-unknown-unknown/release/m.wasm --out-dir build --target web
````
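For readers new to `wasm-bindgen`, the `--target web` flag makes `build-lib.sh` emit an ES module whose default export fetches and instantiates the `.wasm` binary. A minimal, hypothetical consumer of the resulting bundle might look like the sketch below (the URLs and `loadModel` helper are placeholders, not part of the commit; the actual worker added later in this diff resolves them from the UI):

```js
// Minimal sketch: load the wasm-bindgen bundle and construct a model.
// init() fetches build/m_bg.wasm relative to m.js by default.
import init, { Model } from "./build/m.js";

async function loadModel(weightsURL, tokenizerURL) {
  await init();
  const [weights, tokenizer] = await Promise.all(
    [weightsURL, tokenizerURL].map(async (url) =>
      new Uint8Array(await (await fetch(url)).arrayBuffer())
    )
  );
  // Same constructor signature the worker below uses.
  return new Model(weights, tokenizer);
}
```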
````diff
diff --git a/candle-wasm-examples/llama2-c/lib-example.html b/candle-wasm-examples/llama2-c/lib-example.html
new file mode 100644
index 00000000..bc519e4b
--- /dev/null
+++ b/candle-wasm-examples/llama2-c/lib-example.html
@@ -0,0 +1,311 @@
+<!doctype html>
+<html>
+  <head>
+    <meta charset="UTF-8" />
+    <meta name="viewport" content="width=device-width, initial-scale=1.0" />
+    <title>Candle Llama2.c Rust/WASM</title>
+    <style>
+      @import url("https://fonts.googleapis.com/css2?family=Source+Code+Pro:wght@200;300;400&family=Source+Sans+3:wght@100;200;300;400;500;600;700;800;900&display=swap");
+      html,
+      body {
+        font-family: "Source Sans 3", sans-serif;
+      }
+      code,
+      output,
+      select,
+      pre {
+        font-family: "Source Code Pro", monospace;
+      }
+    </style>
+    <script src="https://cdn.tailwindcss.com"></script>
+    <script type="module">
+      // base URL for the model weights
+      const MODELS_BASE_URL =
+        "https://huggingface.co/karpathy/tinyllamas/resolve/main";
+
+      // available models and their sequence lengths
+      const MODELS = {
+        stories15M: { url: "stories15M.bin", seq_len: 256 },
+        stories42M: { url: "stories42M.bin", seq_len: 256 },
+        stories110M: { url: "stories110M.bin", seq_len: 256 },
+      };
+
+      const llamaWorker = new Worker("./llama2cWorker.js", { type: "module" });
+
+      async function generateSequence(controller) {
+        const getValue = (id) => document.querySelector(`#${id}`).value;
+        const modelID = getValue("model");
+        const model = MODELS[modelID];
+        const weightsURL = `${MODELS_BASE_URL}/${model.url}`;
+        const prompt = getValue("prompt");
+        const temperature = getValue("temperature");
+        const repeatPenalty = getValue("repeat_penalty");
+        const seed = getValue("seed");
+        const maxSeqLen = getValue("max-seq");
+
+        function updateStatus({ status, message, prompt, sentence }) {
+          const outStatus = document.querySelector("#output-status");
+          const outGen = document.querySelector("#output-generation");
+
+          switch (status) {
+            case "loading":
+              outStatus.hidden = false;
+              outStatus.textContent = message;
+              outGen.hidden = true;
+              break;
+            case "generating":
+              outStatus.hidden = true;
+              outGen.hidden = false;
+              outGen.innerHTML = `<span class="font-semibold">${prompt}</span>${sentence.replace(/\<s\>|\<\/s\>/g, "")}`;
+              break;
+            case "complete":
+              outStatus.hidden = true;
+              outGen.hidden = false;
+              break;
+          }
+        }
+
+        return new Promise((resolve, reject) => {
+          llamaWorker.postMessage({
+            weightsURL,
+            modelID,
+            tokenizerURL: "tokenizer.json",
+            prompt,
+            temp: temperature,
+            repeatPenalty,
+            seed: BigInt(seed),
+            maxSeqLen,
+            command: "start",
+          });
+
+          const handleAbort = () => {
+            llamaWorker.postMessage({ command: "abort" });
+          };
+          const handleMessage = (event) => {
+            const { status, error } = event.data;
+            if (status) updateStatus(event.data);
+            if (error) reject(new Error(error));
+            if (status === "complete") resolve(event.data);
+          };
+
+          controller.signal.addEventListener("abort", handleAbort);
+          llamaWorker.addEventListener("message", handleMessage);
+        });
+      }
+
+      const form = document.querySelector("#form");
+      const prompt = document.querySelector("#prompt");
+      const clearBtn = document.querySelector("#clear-btn");
+      const runBtn = document.querySelector("#run");
+      let runController = new AbortController();
+      let isRunning = false;
+
+      form.addEventListener("submit", async (e) => {
+        e.preventDefault();
+        if (isRunning) {
+          stopRunning();
+        } else {
+          startRunning();
+          await generateSequence(runController);
+          stopRunning();
+        }
+      });
+
+      function startRunning() {
+        isRunning = true;
+        runBtn.textContent = "Stop";
+      }
+
+      function stopRunning() {
+        runController.abort();
+        runController = new AbortController();
+        runBtn.textContent = "Run";
+        isRunning = false;
+      }
+
+      clearBtn.addEventListener("click", (e) => {
+        e.preventDefault();
+        prompt.value = "";
+        clearBtn.classList.add("invisible");
+        runBtn.disabled = true;
+        stopRunning();
+      });
+
+      prompt.addEventListener("input", (e) => {
+        runBtn.disabled = false;
+        if (e.target.value.length > 0) {
+          clearBtn.classList.remove("invisible");
+        } else {
+          clearBtn.classList.add("invisible");
+        }
+      });
+    </script>
+  </head>
+  <body class="container max-w-4xl mx-auto p-4 text-gray-800">
+    <main class="grid grid-cols-1 gap-8 relative">
+      <span class="absolute text-5xl -ml-[1em]">🕯️</span>
+      <div>
+        <h1 class="text-5xl font-bold">Candle Llama2.c</h1>
+        <h2 class="text-2xl font-bold">Rust/WASM Demo</h2>
+        <p class="max-w-lg">
+          <a href="https://github.com/karpathy/llama2.c" target="_blank" class="underline hover:text-blue-500 hover:no-underline">Llama2.c</a>
+          is Andrej Karpathy's C implementation of the Llama 2 LLM. This demo uses
+          <a href="https://github.com/huggingface/candle/" target="_blank" class="underline hover:text-blue-500 hover:no-underline">Candle</a>
+          to run Llama2.c in the browser via Rust/WASM.
+        </p>
+      </div>
+      <div>
+        <label for="model" class="font-medium">Model Options: </label>
+        <select id="model" class="border-2 border-gray-500 rounded-md font-light">
+          <option value="stories15M" selected>stories 15M (60.8 MB)</option>
+          <option value="stories42M">stories 42M (167 MB)</option>
+          <option value="stories110M">stories 110M (438 MB)</option>
+        </select>
+      </div>
+      <form id="form" class="flex text-normal px-1 py-1 border border-gray-700 rounded-md items-center">
+        <input type="submit" hidden />
+        <input type="text" id="prompt" class="font-light w-full px-3 py-2 mx-1 resize-none outline-none" placeholder="Add your prompt here..." />
+        <button class="invisible" id="clear-btn">
+          <svg fill="none" xmlns="http://www.w3.org/2000/svg" width="40" viewBox="0 0 70 40">
+            <path opacity=".5" d="M39 .2v40.2" stroke="#1F2937" />
+            <path d="M1.5 11.5 19 29.1m0-17.6L1.5 29.1" opacity=".5" stroke="#1F2937" stroke-width="2" />
+          </svg>
+        </button>
+        <button id="run" disabled class="bg-gray-700 hover:bg-gray-800 text-white font-normal py-2 w-16 rounded disabled:bg-gray-300 disabled:cursor-not-allowed">
+          Run
+        </button>
+      </form>
+      <div class="grid grid-cols-3 max-w-md items-center gap-3">
+        <label class="text-sm font-medium" for="max-seq">Maximum length</label>
+        <input type="range" id="max-seq" name="max-seq" min="1" max="256" step="1" value="200" oninput="this.nextElementSibling.value = Number(this.value)" />
+        <output class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md">200</output>
+        <label class="text-sm font-medium" for="temperature">Temperature</label>
+        <input type="range" id="temperature" name="temperature" min="0" max="2" step="0.01" value="0.50" oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)" />
+        <output class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md">0.50</output>
+        <label class="text-sm font-medium" for="repeat_penalty">Repeat Penalty</label>
+        <input type="range" id="repeat_penalty" name="repeat_penalty" min="-2" max="2" step="0.01" value="1.10" oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)" />
+        <output class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md">1.10</output>
+        <label class="text-sm font-medium" for="seed">Seed</label>
+        <input type="number" id="seed" name="seed" value="299792458" class="font-light border border-gray-700 text-right rounded-md p-2" />
+      </div>
+      <div>
+        <h3 class="font-medium">Generation:</h3>
+        <div class="min-h-[250px] bg-slate-100 text-gray-500 p-4 rounded-md grid">
+          <p hidden id="output-generation"></p>
+          <span id="output-status" class="justify-self-center self-center font-light">No output yet</span>
+        </div>
+      </div>
+    </main>
+  </body>
+</html>
````
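The page drives the worker through a small ad-hoc message protocol: a `start` message carries all generation parameters, `abort` cancels a run, and the worker reports back with `loading`, `generating`, `complete`, or `aborted` statuses (or an `error`). As a reading aid, not part of the commit, a `start` request mirroring what `generateSequence()` sends looks roughly like this:

```js
// Hypothetical "start" request, with values as the form controls would produce them.
llamaWorker.postMessage({
  command: "start",
  weightsURL: `${MODELS_BASE_URL}/stories15M.bin`,
  modelID: "stories15M",
  tokenizerURL: "tokenizer.json",
  prompt: "Once upon a time",
  temp: "0.50",         // form values arrive as strings
  repeatPenalty: "1.10",
  seed: 299792458n,     // BigInt, matching the u64 seed on the Rust side
  maxSeqLen: "200",
});
```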
````diff
diff --git a/candle-wasm-examples/llama2-c/llama2cWorker.js b/candle-wasm-examples/llama2-c/llama2cWorker.js
new file mode 100644
index 00000000..ba303aaa
--- /dev/null
+++ b/candle-wasm-examples/llama2-c/llama2cWorker.js
@@ -0,0 +1,96 @@
+import init, { Model } from "./build/m.js";
+
+async function fetchArrayBuffer(url) {
+  const res = await fetch(url, { cache: "force-cache" });
+  const data = await res.arrayBuffer();
+  return new Uint8Array(data);
+}
+
+class Llama2C {
+  static instance = {};
+
+  static async getInstance(weightsURL, modelID, tokenizerURL) {
+    // load each modelID only once
+    if (!this.instance[modelID]) {
+      await init();
+
+      self.postMessage({ status: "loading", message: "Loading Model" });
+
+      const [weightsArrayU8, tokenizerArrayU8] = await Promise.all([
+        fetchArrayBuffer(weightsURL),
+        fetchArrayBuffer(tokenizerURL),
+      ]);
+
+      this.instance[modelID] = new Model(weightsArrayU8, tokenizerArrayU8);
+    }
+    return this.instance[modelID];
+  }
+}
+
+let controller = null;
+self.addEventListener("message", (event) => {
+  if (event.data.command === "start") {
+    controller = new AbortController();
+    generate(event.data);
+  } else if (event.data.command === "abort") {
+    controller.abort();
+  }
+});
+
+async function generate(data) {
+  const {
+    weightsURL,
+    modelID,
+    tokenizerURL,
+    prompt,
+    temp,
+    repeatPenalty,
+    seed,
+    maxSeqLen,
+  } = data;
+  try {
+    self.postMessage({ status: "loading", message: "Starting llama2.c" });
+    const model = await Llama2C.getInstance(weightsURL, modelID, tokenizerURL);
+
+    self.postMessage({ status: "loading", message: "Initializing model" });
+    model.init_with_prompt(prompt, temp, repeatPenalty, seed);
+
+    const seq_len = model.get_seq_len();
+
+    let sentence = "";
+    let max_tokens = maxSeqLen ? maxSeqLen : seq_len - prompt.length - 1;
+
+    while (max_tokens--) {
+      await new Promise(async (resolve) => {
+        if (controller && controller.signal.aborted) {
+          self.postMessage({
+            status: "aborted",
+            message: "Aborted",
+            output: prompt + sentence,
+          });
+          return;
+        }
+        const token = await model.next_token();
+        sentence += token;
+        self.postMessage({
+          status: "generating",
+          message: "Generating token",
+          token: token,
+          sentence: sentence,
+          prompt: prompt,
+        });
+        setTimeout(resolve, 0);
+      });
+    }
+    self.postMessage({
+      status: "complete",
+      message: "complete",
+      output: prompt + sentence,
+    });
+  } catch (e) {
+    self.postMessage({ error: e });
+  }
+}
````
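One detail worth noting: `model.next_token()` blocks the worker thread while the WASM computes, so the loop above wraps each step in a promise resolved via `setTimeout(resolve, 0)`. That yields back to the worker's event loop between tokens, which is what allows a queued `abort` message to be observed mid-generation. The pattern in isolation, as a sketch (the helper names are illustrative, not from the commit):

```js
// Sketch of the yielding pattern used in generate(): run one blocking step,
// then give the event loop a turn so incoming messages can be dispatched.
const nextTick = () => new Promise((resolve) => setTimeout(resolve, 0));

async function runCancellable(step, steps, signal) {
  for (let i = 0; i < steps && !signal.aborted; i++) {
    step();           // synchronous WASM work, e.g. model.next_token()
    await nextTick(); // lets a pending "abort" handler run before the next step
  }
}
```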
````diff
diff --git a/candle-wasm-examples/llama2-c/src/bin/m.rs b/candle-wasm-examples/llama2-c/src/bin/m.rs
index da71f071..62b1bdf7 100644
--- a/candle-wasm-examples/llama2-c/src/bin/m.rs
+++ b/candle-wasm-examples/llama2-c/src/bin/m.rs
@@ -58,6 +58,11 @@ impl Model {
             Err(e) => Err(JsError::new(&e.to_string())),
         }
     }
+    #[wasm_bindgen]
+    pub fn get_seq_len(&mut self) -> usize {
+        let seq_len = self.inner.config.seq_len;
+        seq_len
+    }
 
     #[wasm_bindgen]
     pub fn init_with_prompt(
@@ -65,6 +70,7 @@
         prompt: String,
         temp: f64,
         repeat_penalty: f32,
+        seed: u64,
     ) -> Result<String, JsError> {
         // First reset the cache.
         {
@@ -74,7 +80,7 @@
             }
         }
         let temp = if temp <= 0. { None } else { Some(temp) };
-        self.logits_processor = LogitsProcessor::new(299792458, temp);
+        self.logits_processor = LogitsProcessor::new(seed, temp);
         self.repeat_penalty = repeat_penalty;
         self.tokens.clear();
         let tokens = self
````

````diff
diff --git a/candle-wasm-examples/llama2-c/src/worker.rs b/candle-wasm-examples/llama2-c/src/worker.rs
index 3d187fcc..7e97b5da 100644
--- a/candle-wasm-examples/llama2-c/src/worker.rs
+++ b/candle-wasm-examples/llama2-c/src/worker.rs
@@ -51,7 +51,7 @@ fn read_tensor<R: std::io::Read, S: Into<Shape>>(
 
 pub struct Model {
     pub cache: Cache,
-    config: Config,
+    pub config: Config,
     pub llama: Llama,
     pub tokenizer: Tokenizer,
 }
````
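Together, the Rust changes expose the model's configured context length and make the sampler seedable instead of hardcoding `299792458`. From the JS side, a hedged usage sketch (`model` being the wasm-bindgen `Model` instance from the worker above; the variable names are illustrative):

```js
// Sketch: the two API points this commit adds, called from JS.
const prompt = "Once upon a time";
// u64 arguments cross the wasm-bindgen boundary as BigInt; a fixed seed
// makes sampling reproducible at the same settings.
model.init_with_prompt(prompt, 0.5, 1.1, 299792458n);
const seqLen = model.get_seq_len(); // e.g. 256 for the tinyllamas checkpoints
// Default token budget, as computed in llama2cWorker.js:
let maxTokens = seqLen - prompt.length - 1;
```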