author     Radamés Ajna <radamajna@gmail.com>        2023-12-14 04:04:17 -0800
committer  GitHub <noreply@github.com>               2023-12-14 06:04:17 -0600
commit     104e196d468d6c440c9f1fc504be37b2cbfb9722 (patch)
tree       deee30a79df97f3f3a2454462880b251a9f78af4 /candle-wasm-examples
parent     5e33c85c8f7d2ae8c5fe8de557b69c036e4f080a (diff)
Phi 2 wasm (#1432)
* add phi 2.0 quantized model wasm
* cols
* spell
* bug
Diffstat (limited to 'candle-wasm-examples')
-rw-r--r--  candle-wasm-examples/phi/index.html    | 90
-rw-r--r--  candle-wasm-examples/phi/phiWorker.js  | 17
-rw-r--r--  candle-wasm-examples/phi/src/bin/m.rs  | 21
3 files changed, 102 insertions, 26 deletions
diff --git a/candle-wasm-examples/phi/index.html b/candle-wasm-examples/phi/index.html
index 19c6a586..dbef698a 100644
--- a/candle-wasm-examples/phi/index.html
+++ b/candle-wasm-examples/phi/index.html
@@ -1,7 +1,7 @@
 <html>
   <head>
     <meta content="text/html;charset=utf-8" http-equiv="Content-Type" />
-    <title>Candle Phi 1.5 Rust/WASM</title>
+    <title>Candle Phi 1.5 / Phi 2.0 Rust/WASM</title>
   </head>
   <body></body>
 </html>
@@ -39,7 +39,7 @@
       import hljs from "https://cdn.skypack.dev/highlight.js";
       // models base url
       const MODELS = {
-        phi_1_5_quantized: {
+        phi_1_5_q4k: {
           base_url:
             "https://huggingface.co/lmz/candle-quantized-phi/resolve/main/",
           model: "model-q4k.gguf",
@@ -49,7 +49,7 @@
           seq_len: 2048,
           size: "800 MB",
         },
-        phi_1_5_quantized_2: {
+        phi_1_5_q80: {
           base_url:
             "https://huggingface.co/lmz/candle-quantized-phi/resolve/main/",
           model: "model-q80.gguf",
@@ -59,7 +59,21 @@
           seq_len: 2048,
           size: "1.51 GB",
         },
-        puffin_phi_v2_quantized: {
+        phi_2_0_q4k: {
+          base_url:
+            "https://huggingface.co/radames/phi-2-quantized/resolve/main/",
+          model: [
+            "model-v2-q4k.gguf_aa.part",
+            "model-v2-q4k.gguf_ab.part",
+            "model-v2-q4k.gguf_ac.part",
+          ],
+          tokenizer: "tokenizer.json",
+          config: "config.json",
+          quantized: true,
+          seq_len: 2048,
+          size: "1.57GB",
+        },
+        puffin_phi_v2_q4k: {
           base_url:
             "https://huggingface.co/lmz/candle-quantized-phi/resolve/main/",
           model: "model-puffin-phi-v2-q4k.gguf",
@@ -69,7 +83,7 @@
           seq_len: 2048,
           size: "798 MB",
         },
-        puffin_phi_v2_quantized_2: {
+        puffin_phi_v2_q80: {
           base_url:
             "https://huggingface.co/lmz/candle-quantized-phi/resolve/main/",
           model: "model-puffin-phi-v2-q80.gguf",
@@ -106,8 +120,8 @@ Let’s think step by step.`,
         },
         {
           title: "Question answering",
-          prompt: `What is the capital of France?
-Answer:`,
+          prompt: `Instruct: What is the capital of France?
+Output:`,
         },
         {
           title: "Chat mode",
@@ -148,7 +162,10 @@ Very polite review:`,
       const getValue = (id) => document.querySelector(`#${id}`).value;
       const modelID = getValue("model");
       const model = MODELS[modelID];
-      const weightsURL = model.base_url + model.model;
+      const weightsURL =
+        model.model instanceof Array
+          ? model.model.map((m) => model.base_url + m)
+          : model.base_url + model.model;
       const tokenizerURL = model.base_url + model.tokenizer;
       const configURL = model.base_url + model.config;
 
@@ -246,6 +263,13 @@
         option.innerText = `${id} (${model.size})`;
         modelSelect.appendChild(option);
       }
+      const query = new URLSearchParams(window.location.search);
+      const modelID = query.get("model");
+      if (modelID) {
+        modelSelect.value = modelID;
+      } else {
+        modelSelect.value = "phi_1_5_q4k";
+      }
 
       for (const [i, { title, prompt }] of TEMPLATES.entries()) {
         const div = document.createElement("div");
@@ -270,8 +294,18 @@
           prompt.value = template;
           prompt.style.height = "auto";
           prompt.style.height = prompt.scrollHeight + "px";
+          runBtn.disabled = false;
+          clearBtn.classList.remove("invisible");
         });
       modelSelect.addEventListener("change", (e) => {
+        const query = new URLSearchParams(window.location.search);
+        query.set("model", e.target.value);
+        window.history.replaceState(
+          {},
+          "",
+          `${window.location.pathname}?${query}`
+        );
+        window.parent.postMessage({ queryString: "?" + query }, "*");
         const model = MODELS[e.target.value];
         document.querySelector("#max-seq").max = model.seq_len;
         document.querySelector("#max-seq").nextElementSibling.value = 200;
@@ -320,7 +354,7 @@
     <main class="grid grid-cols-1 gap-8 relative">
       <span class="absolute text-5xl -ml-[1em]"> 🕯️ </span>
       <div>
-        <h1 class="text-5xl font-bold">Candle Phi 1.5</h1>
+        <h1 class="text-5xl font-bold">Candle Phi 1.5 / Phi 2.0</h1>
         <h2 class="text-2xl font-bold">Rust/WASM Demo</h2>
         <p class="max-w-lg">
           The
@@ -330,10 +364,17 @@
             target="_blank"
             >Phi-1.5</a
           >
-          model achieves state-of-the-art performance with only 1.3 billion
-          parameters, compared to models with up to 10 billion. You can try the
-          quantized version of the model here. Additional prompt examples are
-          available in the
+          and
+          <a
+            href="https://huggingface.co/microsoft/phi-2"
+            class="link"
+            target="_blank"
+            >Phi-2</a
+          >
+          models achieve state-of-the-art performance with only 1.3 billion and
+          2.7 billion parameters, compared to larger models with up to 13
+          billion parameters. Here you can try the quantized versions.
+          Additional prompt examples are available in the
           <a
             href="https://arxiv.org/pdf/2309.05463.pdf#page=8"
             class="link"
@@ -350,7 +391,7 @@
             target="_blank"
             >Puffin-Phi V2
           </a>
-          quantized version model, a fine-tuned version of Phi-1.5 on the
+          quantized version, a fine-tuned version of Phi-1.5 on the
           <a
             href="https://huggingface.co/datasets/LDJnr/Puffin"
             class="link"
@@ -363,7 +404,7 @@
         <p class="text-xs italic max-w-lg">
           <b>Note:</b>
           When first run, the app will download and cache the model, which could
-          take a few minutes. The models are <b>~800MB</b> or <b>~1.51GB</b> in
+          take a few minutes. The models are <b>~800MB</b> or <b>~1.57GB</b> in
           size.
         </p>
       </div>
@@ -375,8 +416,13 @@
         ></select>
       </div>
       <div>
-        <h3 class="font-medium">Prompt Templates</h3>
-        <form id="prompt-templates" class="flex flex-col gap-1 my-2"></form>
+        <details>
+          <summary class="font-medium cursor-pointer">Prompt Templates</summary>
+          <form
+            id="prompt-templates"
+            class="grid grid-cols-1 sm:grid-cols-2 gap-1 my-2"
+          ></form>
+        </details>
       </div>
       <form
         id="form"
@@ -386,12 +432,12 @@
         <textarea
           type="text"
           id="prompt"
-          class="font-light w-full px-3 py-2 mx-1 resize-none outline-none"
+          class="font-light text-lg w-full px-3 py-2 mx-1 resize-none outline-none"
           oninput="this.style.height = 0;this.style.height = this.scrollHeight + 'px'"
           placeholder="Add your prompt here..."
         >
-Write a detailed analogy between mathematics and a lighthouse.
-Answer:</textarea
+Instruct: Write a detailed analogy between mathematics and a lighthouse.
+Output:</textarea
         >
         <button id="clear-btn">
           <svg
@@ -517,9 +563,9 @@
           <div
             id="output-counter"
             hidden
-            class="ml-auto font-semibold grid-rows-1 text-sm"
+            class="ml-auto font-semibold grid-rows-1"
           ></div>
-          <p hidden id="output-generation" class="grid-rows-2"></p>
+          <p hidden id="output-generation" class="grid-rows-2 text-lg"></p>
           <span id="output-status" class="m-auto font-light"
             >No output yet</span
           >
diff --git a/candle-wasm-examples/phi/phiWorker.js b/candle-wasm-examples/phi/phiWorker.js
index 5c030f1d..bb71b409 100644
--- a/candle-wasm-examples/phi/phiWorker.js
+++ b/candle-wasm-examples/phi/phiWorker.js
@@ -12,6 +12,20 @@ async function fetchArrayBuffer(url) {
   cache.put(url, res.clone());
   return new Uint8Array(await res.arrayBuffer());
 }
+async function concatenateArrayBuffers(urls) {
+  const arrayBuffers = await Promise.all(urls.map(url => fetchArrayBuffer(url)));
+
+  let totalLength = arrayBuffers.reduce((acc, arrayBuffer) => acc + arrayBuffer.byteLength, 0);
+  let concatenatedBuffer = new Uint8Array(totalLength);
+
+  let offset = 0;
+  arrayBuffers.forEach(buffer => {
+    concatenatedBuffer.set(new Uint8Array(buffer), offset);
+    offset += buffer.byteLength;
+  });
+  return concatenatedBuffer;
+}
+
 
 class Phi {
   static instance = {};
@@ -27,10 +41,9 @@ class Phi {
       await init();
 
       self.postMessage({ status: "loading", message: "Loading Model" });
-
       const [weightsArrayU8, tokenizerArrayU8, configArrayU8] =
         await Promise.all([
-          fetchArrayBuffer(weightsURL),
+          weightsURL instanceof Array ? concatenateArrayBuffers(weightsURL) : fetchArrayBuffer(weightsURL),
           fetchArrayBuffer(tokenizerURL),
           fetchArrayBuffer(configURL),
         ]);
diff --git a/candle-wasm-examples/phi/src/bin/m.rs b/candle-wasm-examples/phi/src/bin/m.rs
index c18e6c38..999f276d 100644
--- a/candle-wasm-examples/phi/src/bin/m.rs
+++ b/candle-wasm-examples/phi/src/bin/m.rs
@@ -5,6 +5,7 @@ use candle_transformers::models::mixformer::{Config, MixFormerSequentialForCausa
 use candle_transformers::models::quantized_mixformer::MixFormerSequentialForCausalLM as QMixFormer;
 use candle_wasm_example_phi::console_log;
 use js_sys::Date;
+use serde::Deserialize;
 use tokenizers::Tokenizer;
 use wasm_bindgen::prelude::*;
 
@@ -23,6 +24,12 @@ pub struct Model {
     repeat_last_n: usize,
 }
 
+#[derive(Debug, Clone, PartialEq, Deserialize)]
+
+pub struct ModelName {
+    pub _name_or_path: String,
+}
+
 #[wasm_bindgen]
 impl Model {
     #[wasm_bindgen(constructor)]
@@ -34,15 +41,25 @@ impl Model {
     ) -> Result<Model, JsError> {
         console_error_panic_hook::set_once();
         console_log!("loading model");
+        let name: ModelName = serde_json::from_slice(&config)?;
         let config: Config = serde_json::from_slice(&config)?;
+
+        console_log!("config loaded {:?}", name);
         let tokenizer =
             Tokenizer::from_bytes(&tokenizer).map_err(|m| JsError::new(&m.to_string()))?;
         let start = Date::now();
+        console_log!("weights len: {:?}", weights.len());
         let model = if quantized {
             let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf_buffer(&weights)?;
-            let model = QMixFormer::new(&config, vb)?;
-            SelectedModel::Quantized(model)
+            console_log!("weights loaded");
+            if name._name_or_path == "microsoft/phi-2" {
+                let model = QMixFormer::new_v2(&config, vb)?;
+                SelectedModel::Quantized(model)
+            } else {
+                let model = QMixFormer::new(&config, vb)?;
+                SelectedModel::Quantized(model)
+            }
         } else {
             let device = &Device::Cpu;
             let vb = VarBuilder::from_buffered_safetensors(weights, DType::F32, device)?;
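
The heart of this change is the multi-part weight loading in phiWorker.js: the quantized Phi-2 GGUF is published as three *.part files (model-v2-q4k.gguf_aa.part and so on), and the worker downloads them in order and stitches them back into one contiguous buffer before handing it to the wasm module. The sketch below shows that pattern in isolation; fetchSplitModel is a hypothetical name introduced here, and it uses plain fetch where the commit's concatenateArrayBuffers helper goes through the Cache API wrapper fetchArrayBuffer.

// Minimal standalone sketch of the split-GGUF loading pattern
// (illustrative; the commit's own helper is concatenateArrayBuffers).
async function fetchSplitModel(urls) {
  // Fetch every part in parallel. `urls` must list the parts in byte
  // order (..._aa.part, ..._ab.part, ..._ac.part).
  const parts = await Promise.all(
    urls.map(async (url) => {
      const res = await fetch(url);
      return new Uint8Array(await res.arrayBuffer());
    })
  );
  // Allocate one buffer large enough for all parts combined.
  const totalLength = parts.reduce((acc, p) => acc + p.byteLength, 0);
  const out = new Uint8Array(totalLength);
  // Copy each part at its running offset to rebuild the original file.
  let offset = 0;
  for (const part of parts) {
    out.set(part, offset);
    offset += part.byteLength;
  }
  return out; // one contiguous GGUF buffer
}

Because MODELS[...].model may now be either a string or an array of part names, both sides of the app branch on instanceof Array: index.html builds weightsURL as a list of URLs when needed, and the worker picks concatenateArrayBuffers over a single fetchArrayBuffer, so single-file models keep the original code path.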