author    Radamés Ajna <radamajna@gmail.com>  2023-09-04 07:59:22 -0700
committer GitHub <noreply@github.com>  2023-09-04 15:59:22 +0100
commit    8395152d20c7c72fb866ca3f8cbcab8859bfed57 (patch)
tree      73d573ed4b89bcfda0d3c01fa37fa3a6b0f39fb1 /candle-wasm-examples/llama2-c
parent    e2f9f60ac2e4ab6b39b8275442f7ad4e76995707 (diff)
Llama2c WASM UI improvements (#732)
* pass seed, expose model seq_len
* wip new llama2.c ui
* final new UI example
* small copy
* copy
Diffstat (limited to 'candle-wasm-examples/llama2-c')
-rw-r--r--  candle-wasm-examples/llama2-c/README.md          |  47
-rw-r--r--  candle-wasm-examples/llama2-c/build-lib.sh       |   2
-rw-r--r--  candle-wasm-examples/llama2-c/lib-example.html   | 311
-rw-r--r--  candle-wasm-examples/llama2-c/llama2cWorker.js   |  96
-rw-r--r--  candle-wasm-examples/llama2-c/src/bin/m.rs       |   8
-rw-r--r--  candle-wasm-examples/llama2-c/src/worker.rs      |   2
6 files changed, 464 insertions, 2 deletions
diff --git a/candle-wasm-examples/llama2-c/README.md b/candle-wasm-examples/llama2-c/README.md
new file mode 100644
index 00000000..0b41e064
--- /dev/null
+++ b/candle-wasm-examples/llama2-c/README.md
@@ -0,0 +1,47 @@
+## Running [llama2.c](https://github.com/karpathy/llama2.c) Examples
+
+Here, we provide two examples of how to run [llama2.c](https://github.com/karpathy/llama2.c) models in the browser, using a Rust implementation compiled to a WASM binary with Candle.
+
+### Pure Rust UI
+
+To build and test the UI made in Rust you will need [Trunk](https://trunkrs.dev/#install).
+Run the following from the `candle-wasm-examples/llama2-c` directory:
+
+Download assets:
+
+```bash
+# Model and tokenizer
+
+wget -c https://huggingface.co/spaces/lmz/candle-llama2/resolve/main/model.bin
+wget -c https://huggingface.co/spaces/lmz/candle-llama2/resolve/main/tokenizer.json
+
+```
+
+Run hot reload server:
+
+```bash
+trunk serve --release --public-url / --port 8080
+```
+
+### Vanilla JS and WebWorkers
+
+To build and test the UI made in Vanilla JS and WebWorkers, first we need to build the WASM library:
+
+```bash
+sh build-lib.sh
+```
+
+This will bundle the library under `./build` and we can import it inside our WebWorker like a normal JS module:
+
+```js
+import init, { Model } from "./build/m.js";
+```
+
+The full example can be found under `./lib-example.html`. All needed assets are fetched from the web, so no need to download anything.
+Finally, you can preview the example by running a local HTTP server. For example:
+
+```bash
+python -m http.server
+```
+
+Then open `http://localhost:8000/lib-example.html` in your browser.
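
Beyond the worker flow in `llama2cWorker.js` (added below), the bundled module can also be driven directly from a page script. The following is a minimal sketch, assuming the `./build` output from `build-lib.sh` is served next to the `model.bin` and `tokenizer.json` downloaded above; the prompt, sampling settings, and 64-token cap are illustrative:

```js
import init, { Model } from "./build/m.js";

async function run() {
  await init(); // initialize the WASM module first

  // Fetch weights and tokenizer as raw bytes, as the worker does.
  const [weights, tokenizer] = await Promise.all(
    ["model.bin", "tokenizer.json"].map(async (url) => {
      const res = await fetch(url);
      return new Uint8Array(await res.arrayBuffer());
    })
  );

  const model = new Model(weights, tokenizer);
  // prompt, temperature, repeat penalty, seed; the seed is a u64 on the
  // Rust side, so it must be passed as a BigInt.
  model.init_with_prompt("Once upon a time", 0.5, 1.1, 299792458n);

  let sentence = "";
  const maxTokens = Math.min(model.get_seq_len(), 64);
  for (let i = 0; i < maxTokens; i++) {
    sentence += await model.next_token();
  }
  console.log(sentence);
}

run();
```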
diff --git a/candle-wasm-examples/llama2-c/build-lib.sh b/candle-wasm-examples/llama2-c/build-lib.sh
new file mode 100644
index 00000000..b0ebb182
--- /dev/null
+++ b/candle-wasm-examples/llama2-c/build-lib.sh
@@ -0,0 +1,2 @@
+cargo build --target wasm32-unknown-unknown --release
+wasm-bindgen ../../target/wasm32-unknown-unknown/release/m.wasm --out-dir build --target web
diff --git a/candle-wasm-examples/llama2-c/lib-example.html b/candle-wasm-examples/llama2-c/lib-example.html
new file mode 100644
index 00000000..bc519e4b
--- /dev/null
+++ b/candle-wasm-examples/llama2-c/lib-example.html
@@ -0,0 +1,311 @@
+<html>
+ <head>
+ <meta content="text/html;charset=utf-8" http-equiv="Content-Type" />
+ <title>Candle Llama2.c Rust/WASM</title>
+ </head>
+ <body></body>
+</html>
+
+<!doctype html>
+<html>
+ <head>
+ <meta charset="UTF-8" />
+ <meta name="viewport" content="width=device-width, initial-scale=1.0" />
+ <style>
+ @import url("https://fonts.googleapis.com/css2?family=Source+Code+Pro:wght@200;300;400&family=Source+Sans+3:wght@100;200;300;400;500;600;700;800;900&display=swap");
+ html,
+ body {
+ font-family: "Source Sans 3", sans-serif;
+ }
+ code,
+ output,
+ select,
+ pre {
+ font-family: "Source Code Pro", monospace;
+ }
+ </style>
+ <script src="https://cdn.tailwindcss.com"></script>
+ <script type="module">
+ // models base url
+ const MODELS_BASE_URL =
+ "https://huggingface.co/karpathy/tinyllamas/resolve/main";
+
+ // available models and their sequence lengths
+ const MODELS = {
+ stories15M: {
+ url: "stories15M.bin",
+ seq_len: 256,
+ },
+ stories42M: {
+ url: "stories42M.bin",
+ seq_len: 256,
+ },
+ stories110M: {
+ url: "stories110M.bin",
+ seq_len: 256,
+ },
+ };
+
+ const llamaWorker = new Worker("./llama2cWorker.js", {
+ type: "module",
+ });
+ async function generateSequence(controller) {
+ const getValue = (id) => document.querySelector(`#${id}`).value;
+ const modelID = getValue("model");
+ const model = MODELS[modelID];
+ const weightsURL = `${MODELS_BASE_URL}/${model.url}`;
+ const prompt = getValue("prompt");
+ const temperature = getValue("temperature");
+ const repeatPenalty = getValue("repeat_penalty");
+ const seed = getValue("seed");
+ const maxSeqLen = getValue("max-seq");
+
+ function updateStatus({ status, message, prompt, sentence }) {
+ const outStatus = document.querySelector("#output-status");
+ const outGen = document.querySelector("#output-generation");
+
+ switch (status) {
+ case "loading":
+ outStatus.hidden = false;
+ outStatus.textContent = message;
+ outGen.hidden = true;
+ break;
+ case "generating":
+ outStatus.hidden = true;
+ outGen.hidden = false;
+ outGen.innerHTML = `<span class="font-semibold">${prompt}</span>${sentence.replace(
+ /\<s\>|\<\/s\>/g,
+ ""
+ )}`;
+ break;
+ case "complete":
+ outStatus.hidden = true;
+ outGen.hidden = false;
+ break;
+ }
+ }
+
+ return new Promise((resolve, reject) => {
+ llamaWorker.postMessage({
+ weightsURL,
+ modelID,
+ tokenizerURL: "tokenizer.json",
+ prompt,
+ temp: temperature,
+ repeatPenalty,
+ seed: BigInt(seed),
+ maxSeqLen,
+ command: "start",
+ });
+
+ const handleAbort = () => {
+ llamaWorker.postMessage({ command: "abort" });
+ };
+ const handleMessage = (event) => {
+ const { status, error, message, prompt, sentence } = event.data;
+ if (status) updateStatus(event.data);
+ if (error) reject(new Error(error));
+ if (status === "complete") resolve(event.data);
+ };
+
+ controller.signal.addEventListener("abort", handleAbort);
+ llamaWorker.addEventListener("message", handleMessage);
+ });
+ }
+
+ const form = document.querySelector("#form");
+ const prompt = document.querySelector("#prompt");
+ const clearBtn = document.querySelector("#clear-btn");
+ const runBtn = document.querySelector("#run");
+ let runController = new AbortController();
+ let isRunning = false;
+
+ form.addEventListener("submit", async (e) => {
+ e.preventDefault();
+ if (isRunning) {
+ stopRunning();
+ } else {
+ startRunning();
+ await generateSequence(runController);
+ stopRunning();
+ }
+ });
+
+ function startRunning() {
+ isRunning = true;
+ runBtn.textContent = "Stop";
+ }
+
+ function stopRunning() {
+ runController.abort();
+ runController = new AbortController();
+ runBtn.textContent = "Run";
+ isRunning = false;
+ }
+ clearBtn.addEventListener("click", (e) => {
+ e.preventDefault();
+ prompt.value = "";
+ clearBtn.classList.add("invisible");
+ runBtn.disabled = true;
+ stopRunning();
+ });
+ prompt.addEventListener("input", (e) => {
+ runBtn.disabled = false;
+ if (e.target.value.length > 0) {
+ clearBtn.classList.remove("invisible");
+ } else {
+ clearBtn.classList.add("invisible");
+ }
+ });
+ </script>
+ </head>
+ <body class="container max-w-4xl mx-auto p-4 text-gray-800">
+ <main class="grid grid-cols-1 gap-8 relative">
+ <span class="absolute text-5xl -ml-[1em]"> 🕯️ </span>
+ <div>
+ <h1 class="text-5xl font-bold">Candle Llama2.c</h1>
+ <h2 class="text-2xl font-bold">Rust/WASM Demo</h2>
+ <p class="max-w-lg">
+ <a
+ href="https://github.com/karpathy/llama2.c"
+ target="_blank"
+ class="underline hover:text-blue-500 hover:no-underline"
+ rel="noopener"
+ >Llama2.c</a
+ >
+ is Andrej Karpathy's implementation of the Llama 2 LLM in C.
+ This demo uses
+ <a
+ href="https://github.com/huggingface/candle/"
+ target="_blank"
+ class="underline hover:text-blue-500 hover:no-underline"
+ >Candle
+ </a>
+ to run Llama2.c in the browser using Rust/WASM.
+ </p>
+ </div>
+
+ <div>
+ <label for="model" class="font-medium">Model Options: </label>
+ <select
+ id="model"
+ class="border-2 border-gray-500 rounded-md font-light"
+ >
+ <option value="stories15M" selected>stories 15M (60.8 MB)</option>
+ <option value="stories42M">stories 42M (167 MB)</option>
+ <option value="stories110M">stories 110M (438 MB)</option>
+ </select>
+ </div>
+ <form
+ id="form"
+ class="flex text-normal px-1 py-1 border border-gray-700 rounded-md items-center"
+ >
+ <input type="submit" hidden />
+ <input
+ type="text"
+ id="prompt"
+ class="font-light w-full px-3 py-2 mx-1 resize-none outline-none"
+ placeholder="Add your prompt here..."
+ />
+ <button class="invisible" id="clear-btn">
+ <svg
+ fill="none"
+ xmlns="http://www.w3.org/2000/svg"
+ width="40"
+ viewBox="0 0 70 40"
+ >
+ <path opacity=".5" d="M39 .2v40.2" stroke="#1F2937" />
+ <path
+ d="M1.5 11.5 19 29.1m0-17.6L1.5 29.1"
+ opacity=".5"
+ stroke="#1F2937"
+ stroke-width="2"
+ />
+ </svg>
+ </button>
+ <button
+ id="run"
+ disabled
+ class="bg-gray-700 hover:bg-gray-800 text-white font-normal py-2 w-16 rounded disabled:bg-gray-300 disabled:cursor-not-allowed"
+ >
+ Run
+ </button>
+ </form>
+ <div class="grid grid-cols-3 max-w-md items-center gap-3">
+ <label class="text-sm font-medium" for="max-seq">Maximum length </label>
+ <input
+ type="range"
+ id="max-seq"
+ name="max-seq"
+ min="1"
+ max="256"
+ step="1"
+ value="200"
+ oninput="this.nextElementSibling.value = Number(this.value)"
+ />
+ <output
+ class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md"
+ >
+ 200</output
+ >
+ <label class="text-sm font-medium" for="temperature">Temperature</label>
+ <input
+ type="range"
+ id="temperature"
+ name="temperature"
+ min="0"
+ max="2"
+ step="0.01"
+ value="0.50"
+ oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)"
+ />
+ <output
+ class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md"
+ >
+ 0.50</output
+ >
+
+ <label class="text-sm font-medium" for="repeat_penalty"
+ >Repeat Penalty</label
+ >
+
+ <input
+ type="range"
+ id="repeat_penalty"
+ name="repeat_penalty"
+ min="-2"
+ max="2"
+ step="0.01"
+ value="1.10"
+ oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)"
+ />
+ <output
+ class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md"
+ >1.10</output
+ >
+ <label class="text-sm font-medium" for="seed">Seed</label>
+ <input
+ type="number"
+ id="seed"
+ name="seed"
+ value="299792458"
+ class="font-light border border-gray-700 text-right rounded-md p-2"
+ />
+ </div>
+ <div>
+ <h3 class="font-medium">Generation:</h3>
+
+ <div
+ class="min-h-[250px] bg-slate-100 text-gray-500 p-4 rounded-md grid"
+ >
+ <p hidden id="output-generation"></p>
+ <span
+ id="output-status"
+ class="justify-self-center self-center font-light"
+ >No output yet</span
+ >
+ </div>
+ </div>
+ </main>
+ </body>
+</html>
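
One detail worth calling out in the page above: the "Stop" button does not talk to the worker directly. It aborts a page-side `AbortController`, and the per-run `"abort"` listener attached in `generateSequence` forwards the command to the worker. A stripped-down sketch of that wiring, with the start-message fields omitted (they match the `postMessage` call above):

```js
// Sketch of the abort wiring used in lib-example.html.
const llamaWorker = new Worker("./llama2cWorker.js", { type: "module" });
let runController = new AbortController();

function startRun(startMessage) {
  // Attach the abort handler per run, as generateSequence does above:
  // aborting the controller simply forwards an "abort" command to the worker.
  runController.signal.addEventListener("abort", () => {
    llamaWorker.postMessage({ command: "abort" });
  });
  llamaWorker.postMessage({ ...startMessage, command: "start" });
}

function stopRun() {
  // An AbortController is single-use, so create a fresh one for the next run.
  runController.abort();
  runController = new AbortController();
}
```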
diff --git a/candle-wasm-examples/llama2-c/llama2cWorker.js b/candle-wasm-examples/llama2-c/llama2cWorker.js
new file mode 100644
index 00000000..ba303aaa
--- /dev/null
+++ b/candle-wasm-examples/llama2-c/llama2cWorker.js
@@ -0,0 +1,96 @@
+import init, { Model } from "./build/m.js";
+
+async function fetchArrayBuffer(url) {
+ const res = await fetch(url, {
+ cache: "force-cache",
+ });
+ const data = await res.arrayBuffer();
+ return new Uint8Array(data);
+}
+
+class Llama2C {
+ static instance = {};
+
+ static async getInstance(weightsURL, modelID, tokenizerURL) {
+ // load individual modelID only once
+ if (!this.instance[modelID]) {
+ await init();
+
+ self.postMessage({ status: "loading", message: "Loading Model" });
+
+ const [weightsArrayU8, tokenizerArrayU8] = await Promise.all([
+ fetchArrayBuffer(weightsURL),
+ fetchArrayBuffer(tokenizerURL),
+ ]);
+
+ this.instance[modelID] = new Model(weightsArrayU8, tokenizerArrayU8);
+ }
+ return this.instance[modelID];
+ }
+}
+
+let controller = null;
+self.addEventListener("message", (event) => {
+ if (event.data.command === "start") {
+ controller = new AbortController();
+ generate(event.data);
+ } else if (event.data.command === "abort") {
+ controller.abort();
+ }
+});
+
+async function generate(data) {
+ const {
+ weightsURL,
+ modelID,
+ tokenizerURL,
+ prompt,
+ temp,
+ repeatPenalty,
+ seed,
+ maxSeqLen,
+ } = data;
+ try {
+ self.postMessage({ status: "loading", message: "Starting llama2.c" });
+ const model = await Llama2C.getInstance(weightsURL, modelID, tokenizerURL);
+
+ self.postMessage({ status: "loading", message: "Initializing model" });
+ model.init_with_prompt(prompt, temp, repeatPenalty, seed);
+
+ const seq_len = model.get_seq_len();
+
+ let sentence = "";
+ let max_tokens = maxSeqLen ? maxSeqLen : seq_len - prompt.length - 1;
+
+ while (max_tokens--) {
+ await new Promise(async (resolve) => {
+ if (controller && controller.signal.aborted) {
+ self.postMessage({
+ status: "aborted",
+ message: "Aborted",
+ output: prompt + sentence,
+ });
+ return;
+ }
+ const token = await model.next_token();
+
+ sentence += token;
+ self.postMessage({
+ status: "generating",
+ message: "Generating token",
+ token: token,
+ sentence: sentence,
+ prompt: prompt,
+ });
+ setTimeout(resolve, 0);
+ });
+ }
+ self.postMessage({
+ status: "complete",
+ message: "complete",
+ output: prompt + sentence,
+ });
+ } catch (e) {
+ self.postMessage({ error: e });
+ }
+}
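
With that start/abort protocol in place, the worker can also be exercised without the form UI. A minimal driver sketch: the weights URL points at the `karpathy/tinyllamas` repo used by `lib-example.html`, `tokenizer.json` is assumed to be served next to the page (for instance the one downloaded in the README), and the prompt and sampling settings are illustrative:

```js
// Minimal driver for llama2cWorker.js, without the page UI (sketch).
const worker = new Worker("./llama2cWorker.js", { type: "module" });

worker.addEventListener("message", ({ data }) => {
  if (data.error) {
    console.error(data.error);
  } else if (data.status === "generating") {
    console.log(data.sentence); // text generated so far
  } else if (data.status === "complete") {
    console.log("done:", data.output);
  }
});

worker.postMessage({
  command: "start",
  weightsURL:
    "https://huggingface.co/karpathy/tinyllamas/resolve/main/stories15M.bin",
  modelID: "stories15M",
  tokenizerURL: "tokenizer.json",
  prompt: "Once upon a time",
  temp: 0.5,
  repeatPenalty: 1.1,
  seed: 299792458n, // u64 on the Rust side, so it must be a BigInt
  maxSeqLen: 64,
});

// To stop generation early:
// worker.postMessage({ command: "abort" });
```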
diff --git a/candle-wasm-examples/llama2-c/src/bin/m.rs b/candle-wasm-examples/llama2-c/src/bin/m.rs
index da71f071..62b1bdf7 100644
--- a/candle-wasm-examples/llama2-c/src/bin/m.rs
+++ b/candle-wasm-examples/llama2-c/src/bin/m.rs
@@ -58,6 +58,11 @@ impl Model {
Err(e) => Err(JsError::new(&e.to_string())),
}
}
+ #[wasm_bindgen]
+ pub fn get_seq_len(&mut self) -> usize {
+ let seq_len = self.inner.config.seq_len;
+ seq_len
+ }
#[wasm_bindgen]
pub fn init_with_prompt(
@@ -65,6 +70,7 @@ impl Model {
prompt: String,
temp: f64,
repeat_penalty: f32,
+ seed: u64,
) -> Result<String, JsError> {
// First reset the cache.
{
@@ -74,7 +80,7 @@ impl Model {
}
}
let temp = if temp <= 0. { None } else { Some(temp) };
- self.logits_processor = LogitsProcessor::new(299792458, temp);
+ self.logits_processor = LogitsProcessor::new(seed, temp);
self.repeat_penalty = repeat_penalty;
self.tokens.clear();
let tokens = self
diff --git a/candle-wasm-examples/llama2-c/src/worker.rs b/candle-wasm-examples/llama2-c/src/worker.rs
index 3d187fcc..7e97b5da 100644
--- a/candle-wasm-examples/llama2-c/src/worker.rs
+++ b/candle-wasm-examples/llama2-c/src/worker.rs
@@ -51,7 +51,7 @@ fn read_tensor<R: std::io::Read, S: Into<Shape>>(
pub struct Model {
pub cache: Cache,
- config: Config,
+ pub config: Config,
pub llama: Llama,
pub tokenizer: Tokenizer,
}