From 8395152d20c7c72fb866ca3f8cbcab8859bfed57 Mon Sep 17 00:00:00 2001
From: Radamés Ajna
Date: Mon, 4 Sep 2023 07:59:22 -0700
Subject: Llama2c WASM UI improvements (#732)

* pass seed, expose model seq_len

* wip new llama2.c ui

* final new UI example

* small copy

* copy
---
 candle-wasm-examples/llama2-c/README.md        |  47 ++++
 candle-wasm-examples/llama2-c/build-lib.sh     |   2 +
 candle-wasm-examples/llama2-c/lib-example.html | 311 +++++++++++++++++++++++++
 candle-wasm-examples/llama2-c/llama2cWorker.js |  96 ++++++++
 candle-wasm-examples/llama2-c/src/bin/m.rs     |   8 +-
 candle-wasm-examples/llama2-c/src/worker.rs    |   2 +-
 6 files changed, 464 insertions(+), 2 deletions(-)
 create mode 100644 candle-wasm-examples/llama2-c/README.md
 create mode 100644 candle-wasm-examples/llama2-c/build-lib.sh
 create mode 100644 candle-wasm-examples/llama2-c/lib-example.html
 create mode 100644 candle-wasm-examples/llama2-c/llama2cWorker.js

(limited to 'candle-wasm-examples/llama2-c')

diff --git a/candle-wasm-examples/llama2-c/README.md b/candle-wasm-examples/llama2-c/README.md
new file mode 100644
index 00000000..0b41e064
--- /dev/null
+++ b/candle-wasm-examples/llama2-c/README.md
@@ -0,0 +1,47 @@
+## Running [llama2.c](https://github.com/karpathy/llama2.c) Examples
+
+Here, we provide two examples of how to run [llama2.c](https://github.com/karpathy/llama2.c), written in Rust, using a Candle-compiled WASM binary and runtimes.
+
+### Pure Rust UI
+
+To build and test the UI made in Rust you will need [Trunk](https://trunkrs.dev/#install).
+From the `candle-wasm-examples/llama2-c` directory, run:
+
+Download assets:
+
+```bash
+# Model and tokenizer
+
+wget -c https://huggingface.co/spaces/lmz/candle-llama2/resolve/main/model.bin
+wget -c https://huggingface.co/spaces/lmz/candle-llama2/resolve/main/tokenizer.json
+
+```
+
+Run the hot-reload server:
+
+```bash
+trunk serve --release --public-url / --port 8080
+```
+
+### Vanilla JS and WebWorkers
+
+To build and test the UI made in Vanilla JS and WebWorkers, first we need to build the WASM library:
+
+```bash
+sh build-lib.sh
+```
+
+This will bundle the library under `./build`, and we can import it inside our WebWorker like a normal JS module:
+
+```js
+import init, { Model } from "./build/m.js";
+```
+
+The full example can be found under `./lib-example.html`. All needed assets are fetched from the web, so there is no need to download anything.
+Finally, you can preview the example by running a local HTTP server. For example:
+
+```bash
+python -m http.server
+```
+
+Then open `http://localhost:8000/lib-example.html` in your browser.
diff --git a/candle-wasm-examples/llama2-c/build-lib.sh b/candle-wasm-examples/llama2-c/build-lib.sh
new file mode 100644
index 00000000..b0ebb182
--- /dev/null
+++ b/candle-wasm-examples/llama2-c/build-lib.sh
@@ -0,0 +1,2 @@
+cargo build --target wasm32-unknown-unknown --release
+wasm-bindgen ../../target/wasm32-unknown-unknown/release/m.wasm --out-dir build --target web
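As a quick illustration of the README's WebWorker import above, the generated module can also be used directly. The sketch below is not part of this patch: the `Model` constructor and the `init_with_prompt`, `get_seq_len` and `next_token` methods come from the `m.rs` and `llama2cWorker.js` changes further down, the asset URLs from the README, and the sampling defaults from the demo page; the helper names (`run`, `fetchBytes`) and the seed handling are assumptions.

```js
import init, { Model } from "./build/m.js";

async function fetchBytes(url) {
  const res = await fetch(url, { cache: "force-cache" });
  return new Uint8Array(await res.arrayBuffer());
}

// Hypothetical helper: generate a full completion without the worker plumbing.
async function run(prompt) {
  await init(); // instantiate the WASM module produced by build-lib.sh

  // Same assets as the wget commands in the README.
  const weights = await fetchBytes(
    "https://huggingface.co/spaces/lmz/candle-llama2/resolve/main/model.bin"
  );
  const tokenizer = await fetchBytes(
    "https://huggingface.co/spaces/lmz/candle-llama2/resolve/main/tokenizer.json"
  );

  const model = new Model(weights, tokenizer);
  const seqLen = model.get_seq_len(); // the model's maximum sequence length

  // temp, repeat_penalty and seed are forwarded to the Rust side; the seed is
  // a u64 there, so it is passed as a BigInt here.
  model.init_with_prompt(prompt, 0.5, 1.1, BigInt(299792458));

  let sentence = "";
  let maxTokens = seqLen - prompt.length - 1;
  while (maxTokens--) {
    sentence += await model.next_token();
  }
  return prompt + sentence;
}
```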
diff --git a/candle-wasm-examples/llama2-c/lib-example.html b/candle-wasm-examples/llama2-c/lib-example.html
new file mode 100644
index 00000000..bc519e4b
--- /dev/null
+++ b/candle-wasm-examples/llama2-c/lib-example.html
@@ -0,0 +1,311 @@
+<!-- "Candle Llama.c Rust/WASM" demo page (full 311-line markup not reproduced
+     here; only the visible content is summarized). The page shows a 🕯️ header
+     with the titles "Candle Llama2.c" and "Rust/WASM Demo", and the intro:
+     "Llama2.c is Andrej Karpathy's implementation of the Llama 2 LLM in C.
+     This demo uses Candle to run llama2.c in the browser using Rust/WASM."
+     Below that it offers a prompt input with controls to start and stop
+     generation, sliders for maximum tokens (200), temperature (0.50) and
+     repeat penalty (1.10), and a "Generation:" output area ("No output yet")
+     that streams the tokens produced by llama2cWorker.js. -->
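Since the page's script is only summarized above, here is a rough sketch (not part of this patch) of the main-thread side it implements: posting a `start` message to `llama2cWorker.js` and rendering the status updates the worker sends back. The `command` names, message fields and `status` values match the worker implementation below; the asset URLs, `modelID`, element id and the concrete slider values are illustrative assumptions.

```js
// Hypothetical main-thread driver for llama2cWorker.js (defined below).
const worker = new Worker("./llama2cWorker.js", { type: "module" });

function startGeneration(prompt) {
  worker.postMessage({
    command: "start",
    weightsURL:
      "https://huggingface.co/spaces/lmz/candle-llama2/resolve/main/model.bin",
    tokenizerURL:
      "https://huggingface.co/spaces/lmz/candle-llama2/resolve/main/tokenizer.json",
    modelID: "stories", // cache key used by Llama2C.getInstance
    prompt,
    temp: 0.5, // temperature slider default
    repeatPenalty: 1.1, // repeat-penalty slider default
    seed: BigInt(299792458), // forwarded to the new u64 seed parameter
    maxSeqLen: 200, // max-tokens slider default
  });
}

function stopGeneration() {
  worker.postMessage({ command: "abort" });
}

worker.addEventListener("message", ({ data }) => {
  const output = document.querySelector("#output"); // hypothetical element
  if (data.error) {
    console.error(data.error);
  } else if (data.status === "generating") {
    output.textContent = data.prompt + data.sentence;
  } else if (data.status === "complete" || data.status === "aborted") {
    output.textContent = data.output;
  } else if (data.status === "loading") {
    output.textContent = data.message;
  }
});
```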
diff --git a/candle-wasm-examples/llama2-c/llama2cWorker.js b/candle-wasm-examples/llama2-c/llama2cWorker.js
new file mode 100644
index 00000000..ba303aaa
--- /dev/null
+++ b/candle-wasm-examples/llama2-c/llama2cWorker.js
@@ -0,0 +1,96 @@
+import init, { Model } from "./build/m.js";
+
+async function fetchArrayBuffer(url) {
+  const res = await fetch(url, {
+    cache: "force-cache",
+  });
+  const data = await res.arrayBuffer();
+  return new Uint8Array(data);
+}
+
+class Llama2C {
+  static instance = {};
+
+  static async getInstance(weightsURL, modelID, tokenizerURL) {
+    // load individual modelID only once
+    if (!this.instance[modelID]) {
+      await init();
+
+      self.postMessage({ status: "loading", message: "Loading Model" });
+
+      const [weightsArrayU8, tokenizerArrayU8] = await Promise.all([
+        fetchArrayBuffer(weightsURL),
+        fetchArrayBuffer(tokenizerURL),
+      ]);
+
+      this.instance[modelID] = new Model(weightsArrayU8, tokenizerArrayU8);
+    }
+    return this.instance[modelID];
+  }
+}
+
+let controller = null;
+self.addEventListener("message", (event) => {
+  if (event.data.command === "start") {
+    controller = new AbortController();
+    generate(event.data);
+  } else if (event.data.command === "abort") {
+    controller.abort();
+  }
+});
+
+async function generate(data) {
+  const {
+    weightsURL,
+    modelID,
+    tokenizerURL,
+    prompt,
+    temp,
+    repeatPenalty,
+    seed,
+    maxSeqLen,
+  } = data;
+  try {
+    self.postMessage({ status: "loading", message: "Starting llama2.c" });
+    const model = await Llama2C.getInstance(weightsURL, modelID, tokenizerURL);
+
+    self.postMessage({ status: "loading", message: "Initializing model" });
+    model.init_with_prompt(prompt, temp, repeatPenalty, seed);
+
+    const seq_len = model.get_seq_len();
+
+    let sentence = "";
+    let max_tokens = maxSeqLen ? maxSeqLen : seq_len - prompt.length - 1;
+
+    while (max_tokens--) {
+      await new Promise(async (resolve) => {
+        if (controller && controller.signal.aborted) {
+          self.postMessage({
+            status: "aborted",
+            message: "Aborted",
+            output: prompt + sentence,
+          });
+          return;
+        }
+        const token = await model.next_token();
+        sentence += token;
+        self.postMessage({
+          status: "generating",
+          message: "Generating token",
+          token: token,
+          sentence: sentence,
+          prompt: prompt,
+        });
+        setTimeout(resolve, 0);
+      });
+    }
+    self.postMessage({
+      status: "complete",
+      message: "complete",
+      output: prompt + sentence,
+    });
+  } catch (e) {
+    self.postMessage({ error: e });
+  }
+}
diff --git a/candle-wasm-examples/llama2-c/src/bin/m.rs b/candle-wasm-examples/llama2-c/src/bin/m.rs
index da71f071..62b1bdf7 100644
--- a/candle-wasm-examples/llama2-c/src/bin/m.rs
+++ b/candle-wasm-examples/llama2-c/src/bin/m.rs
@@ -58,6 +58,11 @@ impl Model {
             Err(e) => Err(JsError::new(&e.to_string())),
         }
     }
+    #[wasm_bindgen]
+    pub fn get_seq_len(&mut self) -> usize {
+        let seq_len = self.inner.config.seq_len;
+        seq_len
+    }
 
     #[wasm_bindgen]
     pub fn init_with_prompt(
@@ -65,6 +70,7 @@ impl Model {
         prompt: String,
         temp: f64,
         repeat_penalty: f32,
+        seed: u64,
     ) -> Result<String, JsError> {
         // First reset the cache.
         {
@@ -74,7 +80,7 @@ impl Model {
             }
         }
         let temp = if temp <= 0. { None } else { Some(temp) };
-        self.logits_processor = LogitsProcessor::new(299792458, temp);
+        self.logits_processor = LogitsProcessor::new(seed, temp);
         self.repeat_penalty = repeat_penalty;
         self.tokens.clear();
         let tokens = self
diff --git a/candle-wasm-examples/llama2-c/src/worker.rs b/candle-wasm-examples/llama2-c/src/worker.rs
index 3d187fcc..7e97b5da 100644
--- a/candle-wasm-examples/llama2-c/src/worker.rs
+++ b/candle-wasm-examples/llama2-c/src/worker.rs
@@ -51,7 +51,7 @@ fn read_tensor<R: std::io::Read, S: Into<Shape>>(
 pub struct Model {
     pub cache: Cache,
-    config: Config,
+    pub config: Config,
     pub llama: Llama,
     pub tokenizer: Tokenizer,
 }
-- 
cgit v1.2.3