diff options
-rw-r--r-- | candle-transformers/src/models/mixformer.rs | 3 | ||||
-rw-r--r-- | candle-wasm-examples/phi/index.html | 212 | ||||
-rw-r--r-- | candle-wasm-examples/phi/phiWorker.js | 21 | ||||
-rw-r--r-- | candle-wasm-examples/phi/src/bin/m.rs | 9 |
4 files changed, 206 insertions, 39 deletions
diff --git a/candle-transformers/src/models/mixformer.rs b/candle-transformers/src/models/mixformer.rs index 0f2c199b..33aefbfe 100644 --- a/candle-transformers/src/models/mixformer.rs +++ b/candle-transformers/src/models/mixformer.rs @@ -4,11 +4,12 @@ use crate::models::with_tracing::{linear, Embedding as E, Linear}; /// https://arxiv.org/abs/2309.05463 use candle::{DType, Device, IndexOp, Module, Result, Tensor, D}; use candle_nn::{Activation, VarBuilder}; +use serde::Deserialize; const MAX_SEQ_LEN: usize = 4096; // https://huggingface.co/microsoft/phi-1_5/blob/main/configuration_mixformer_sequential.py -#[derive(Debug, Clone, PartialEq)] +#[derive(Debug, Clone, PartialEq, Deserialize)] pub struct Config { pub(crate) vocab_size: usize, pub(crate) n_positions: usize, diff --git a/candle-wasm-examples/phi/index.html b/candle-wasm-examples/phi/index.html index 6b6d589b..19c6a586 100644 --- a/candle-wasm-examples/phi/index.html +++ b/candle-wasm-examples/phi/index.html @@ -13,7 +13,8 @@ <meta name="viewport" content="width=device-width, initial-scale=1.0" /> <link rel="stylesheet" - href="https://cdn.jsdelivr.net/gh/highlightjs/cdn-release@11.8.0/build/styles/default.min.css" /> + href="https://cdn.jsdelivr.net/gh/highlightjs/cdn-release@11.8.0/build/styles/default.min.css" + /> <style> @import url("https://fonts.googleapis.com/css2?family=Source+Code+Pro:wght@200;300;400&family=Source+Sans+3:wght@100;200;300;400;500;600;700;800;900&display=swap"); html, @@ -36,27 +37,110 @@ <script type="module"> import snarkdown from "https://cdn.skypack.dev/snarkdown"; import hljs from "https://cdn.skypack.dev/highlight.js"; - - const TOKENIZER_URL = - "https://huggingface.co/microsoft/phi-1_5/raw/main/tokenizer.json"; // models base url const MODELS = { phi_1_5_quantized: { base_url: "https://huggingface.co/lmz/candle-quantized-phi/resolve/main/", model: "model-q4k.gguf", + tokenizer: "tokenizer.json", + config: "phi-1_5.json", quantized: true, seq_len: 2048, + size: "800 MB", }, 
phi_1_5_quantized_2: { base_url: "https://huggingface.co/lmz/candle-quantized-phi/resolve/main/", model: "model-q80.gguf", + tokenizer: "tokenizer.json", + config: "phi-1_5.json", quantized: true, seq_len: 2048, + size: "1.51 GB", + }, + puffin_phi_v2_quantized: { + base_url: + "https://huggingface.co/lmz/candle-quantized-phi/resolve/main/", + model: "model-puffin-phi-v2-q4k.gguf", + tokenizer: "tokenizer-puffin-phi-v2.json", + config: "puffin-phi-v2.json", + quantized: true, + seq_len: 2048, + size: "798 MB", + }, + puffin_phi_v2_quantized_2: { + base_url: + "https://huggingface.co/lmz/candle-quantized-phi/resolve/main/", + model: "model-puffin-phi-v2-q80.gguf", + tokenizer: "tokenizer-puffin-phi-v2.json", + config: "puffin-phi-v2.json", + quantized: true, + seq_len: 2048, + size: "1.50 GB", }, }; + const TEMPLATES = [ + { + title: "Simple prompt", + prompt: `Sebastien is in London today, it’s the middle of July yet it’s raining, so Sebastien is feeling gloomy. He`, + }, + { + title: "Think step by step", + prompt: `Suppose Alice originally had 3 apples, then Bob gave Alice 7 apples, then Alice gave Cook 5 apples, and then Tim gave Alice 3x the amount of apples Alice had. How many apples does Alice have now? +Let’s think step by step.`, + }, + { + title: "Explaining a code snippet", + prompt: `What does this script do? +\`\`\`python +s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) +s.bind(('', 0)) +s.listen(1) +conn, addr = s.accept() +print('Connected by', addr) +return conn.getsockname()[1] +\`\`\` +Let’s think step by step.`, + }, + { + title: "Question answering", + prompt: `What is the capital of France? +Answer:`, + }, + { + title: "Chat mode", + prompt: `Alice: Can you tell me how to create a python application to go through all the files +in one directory where the file’s name DOES NOT end with '.json'? 
+Bob:`, + }, + { + title: "Python code completion", + prompt: `"""write a python function called batch(function, list) which call function(x) for x in +list in parallel""" +Solution:`, + }, + { + title: "Python Sample", + prompt: `"""Can you make sure those histograms appear side by side on the same plot: +\`\`\`python +plt.hist(intreps_retrained[0][1].view(64,-1).norm(dim=1).detach().cpu().numpy(), bins = 20) +plt.hist(intreps_pretrained[0][1].view(64,-1).norm(dim=1).detach().cpu().numpy(), bins = 20) +\`\`\` +"""`, + }, + { + title: "Write a Twitter post", + prompt: `Write a twitter post for the discovery of gravitational wave. +Twitter Post:`, + }, + { + title: "Write a review", + prompt: `Write a polite review complaining that the video game 'Random Game' was too badly optimized and it burned my laptop. +Very polite review:`, + }, + ]; const phiWorker = new Worker("./phiWorker.js", { type: "module", }); @@ -65,6 +149,8 @@ const modelID = getValue("model"); const model = MODELS[modelID]; const weightsURL = model.base_url + model.model; + const tokenizerURL = model.base_url + model.tokenizer; + const configURL = model.base_url + model.config; const prompt = getValue("prompt").trim(); const temperature = getValue("temperature"); @@ -107,7 +193,8 @@ phiWorker.postMessage({ weightsURL, modelID, - tokenizerURL: TOKENIZER_URL, + tokenizerURL, + configURL, quantized: model.quantized, prompt, temp: temperature, @@ -148,9 +235,42 @@ const clearBtn = document.querySelector("#clear-btn"); const runBtn = document.querySelector("#run"); const modelSelect = document.querySelector("#model"); + const promptTemplates = document.querySelector("#prompt-templates"); let runController = new AbortController(); let isRunning = false; + document.addEventListener("DOMContentLoaded", () => { + for (const [id, model] of Object.entries(MODELS)) { + const option = document.createElement("option"); + option.value = id; + option.innerText = `${id} (${model.size})`; + 
modelSelect.appendChild(option); + } + + for (const [i, { title, prompt }] of TEMPLATES.entries()) { + const div = document.createElement("div"); + const input = document.createElement("input"); + input.type = "radio"; + input.name = "task"; + input.id = `templates-${i}`; + input.classList.add("font-light", "cursor-pointer"); + input.value = prompt; + const label = document.createElement("label"); + label.htmlFor = `templates-${i}`; + label.classList.add("cursor-pointer"); + label.innerText = title; + div.appendChild(input); + div.appendChild(label); + promptTemplates.appendChild(div); + } + }); + + promptTemplates.addEventListener("change", (e) => { + const template = e.target.value; + prompt.value = template; + prompt.style.height = "auto"; + prompt.style.height = prompt.scrollHeight + "px"; + }); modelSelect.addEventListener("change", (e) => { const model = MODELS[e.target.value]; document.querySelector("#max-seq").max = model.seq_len; @@ -217,10 +337,27 @@ <a href="https://arxiv.org/pdf/2309.05463.pdf#page=8" class="link" - target="_blank"> + target="_blank" + > technical report </a >. 
</p> + <p class="max-w-lg"> + You can also try + <a + href="https://huggingface.co/teknium/Puffin-Phi-v2" + class="link" + target="_blank" + >Puffin-Phi V2 + </a> + quantized version model, a fine-tuned version of Phi-1.5 on the + <a + href="https://huggingface.co/datasets/LDJnr/Puffin" + class="link" + target="_blank" + >Puffin dataset + </a>. + </p> </div> <div> <p class="text-xs italic max-w-lg"> @@ -234,26 +371,25 @@ <label for="model" class="font-medium">Models Options: </label> <select id="model" - class="border-2 border-gray-500 rounded-md font-light"> - <option value="phi_1_5_quantized" selected> - phi 1.5 quantized q4k (800 MB) - </option> - <option value="phi_1_5_quantized_2"> - phi 1.5 quantized q80 (1.51 GB) - </option> - <!-- <option value="phi_1_5">phi 1.5 (2.84 GB)</option> --> - </select> + class="border-2 border-gray-500 rounded-md font-light" + ></select> + </div> + <div> + <h3 class="font-medium">Prompt Templates</h3> + <form id="prompt-templates" class="flex flex-col gap-1 my-2"></form> </div> <form id="form" - class="flex text-normal px-1 py-1 border border-gray-700 rounded-md items-center"> + class="flex text-normal px-1 py-1 border border-gray-700 rounded-md items-center" + > <input type="submit" hidden /> <textarea type="text" id="prompt" class="font-light w-full px-3 py-2 mx-1 resize-none outline-none" oninput="this.style.height = 0;this.style.height = this.scrollHeight + 'px'" - placeholder="Add your prompt here..."> + placeholder="Add your prompt here..." + > Write a detailed analogy between mathematics and a lighthouse. 
Answer:</textarea > @@ -262,18 +398,21 @@ Answer:</textarea fill="none" xmlns="http://www.w3.org/2000/svg" width="40" - viewBox="0 0 70 40"> + viewBox="0 0 70 40" + > <path opacity=".5" d="M39 .2v40.2" stroke="#1F2937" /> <path d="M1.5 11.5 19 29.1m0-17.6L1.5 29.1" opacity=".5" stroke="#1F2937" - stroke-width="2" /> + stroke-width="2" + /> </svg> </button> <button id="run" - class="bg-gray-700 hover:bg-gray-800 text-white font-normal py-2 w-16 rounded disabled:bg-gray-300 disabled:cursor-not-allowed"> + class="bg-gray-700 hover:bg-gray-800 text-white font-normal py-2 w-16 rounded disabled:bg-gray-300 disabled:cursor-not-allowed" + > Run </button> </form> @@ -292,9 +431,11 @@ Answer:</textarea max="2048" step="1" value="200" - oninput="this.nextElementSibling.value = Number(this.value)" /> + oninput="this.nextElementSibling.value = Number(this.value)" + /> <output - class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md"> + class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md" + > 200</output > <label class="text-sm font-medium" for="temperature" @@ -308,9 +449,11 @@ Answer:</textarea max="2" step="0.01" value="0.00" - oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)" /> + oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)" + /> <output - class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md"> + class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md" + > 0.00</output > <label class="text-sm font-medium" for="top-p">Top-p</label> @@ -322,9 +465,11 @@ Answer:</textarea max="1" step="0.01" value="1.00" - oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)" /> + oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)" + /> <output - class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md"> + class="text-xs w-[50px] 
text-center font-light px-1 py-1 border border-gray-700 rounded-md" + > 1.00</output > @@ -340,7 +485,8 @@ Answer:</textarea max="2" step="0.01" value="1.10" - oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)" /> + oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)" + /> <output class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md" >1.10</output @@ -351,11 +497,13 @@ Answer:</textarea id="seed" name="seed" value="299792458" - class="font-light border border-gray-700 text-right rounded-md p-2" /> + class="font-light border border-gray-700 text-right rounded-md p-2" + /> <button id="run" onclick="document.querySelector('#seed').value = Math.floor(Math.random() * Number.MAX_SAFE_INTEGER)" - class="bg-gray-700 hover:bg-gray-800 text-white font-normal py-1 w-[50px] rounded disabled:bg-gray-300 disabled:cursor-not-allowed text-sm"> + class="bg-gray-700 hover:bg-gray-800 text-white font-normal py-1 w-[50px] rounded disabled:bg-gray-300 disabled:cursor-not-allowed text-sm" + > Rand </button> </div> @@ -364,11 +512,13 @@ Answer:</textarea <div> <h3 class="font-medium">Generation:</h3> <div - class="min-h-[250px] bg-slate-100 text-gray-500 p-4 rounded-md flex flex-col gap-2"> + class="min-h-[250px] bg-slate-100 text-gray-500 p-4 rounded-md flex flex-col gap-2" + > <div id="output-counter" hidden - class="ml-auto font-semibold grid-rows-1 text-sm"></div> + class="ml-auto font-semibold grid-rows-1 text-sm" + ></div> <p hidden id="output-generation" class="grid-rows-2"></p> <span id="output-status" class="m-auto font-light" >No output yet</span diff --git a/candle-wasm-examples/phi/phiWorker.js b/candle-wasm-examples/phi/phiWorker.js index 17a03e14..5c030f1d 100644 --- a/candle-wasm-examples/phi/phiWorker.js +++ b/candle-wasm-examples/phi/phiWorker.js @@ -15,21 +15,30 @@ async function fetchArrayBuffer(url) { class Phi { static instance = {}; - static async getInstance(weightsURL, modelID, 
tokenizerURL, quantized) { + static async getInstance( + weightsURL, + modelID, + tokenizerURL, + configURL, + quantized + ) { // load individual modelID only once if (!this.instance[modelID]) { await init(); self.postMessage({ status: "loading", message: "Loading Model" }); - const [weightsArrayU8, tokenizerArrayU8] = await Promise.all([ - fetchArrayBuffer(weightsURL), - fetchArrayBuffer(tokenizerURL), - ]); + const [weightsArrayU8, tokenizerArrayU8, configArrayU8] = + await Promise.all([ + fetchArrayBuffer(weightsURL), + fetchArrayBuffer(tokenizerURL), + fetchArrayBuffer(configURL), + ]); this.instance[modelID] = new Model( weightsArrayU8, tokenizerArrayU8, + configArrayU8, quantized ); } @@ -52,6 +61,7 @@ async function generate(data) { weightsURL, modelID, tokenizerURL, + configURL, quantized, prompt, temp, @@ -66,6 +76,7 @@ async function generate(data) { weightsURL, modelID, tokenizerURL, + configURL, quantized ); diff --git a/candle-wasm-examples/phi/src/bin/m.rs b/candle-wasm-examples/phi/src/bin/m.rs index 8fb7db03..c18e6c38 100644 --- a/candle-wasm-examples/phi/src/bin/m.rs +++ b/candle-wasm-examples/phi/src/bin/m.rs @@ -26,10 +26,15 @@ pub struct Model { #[wasm_bindgen] impl Model { #[wasm_bindgen(constructor)] - pub fn load(weights: Vec<u8>, tokenizer: Vec<u8>, quantized: bool) -> Result<Model, JsError> { + pub fn load( + weights: Vec<u8>, + tokenizer: Vec<u8>, + config: Vec<u8>, + quantized: bool, + ) -> Result<Model, JsError> { console_error_panic_hook::set_once(); console_log!("loading model"); - let config: Config = Config::v1_5(); + let config: Config = serde_json::from_slice(&config)?; let tokenizer = Tokenizer::from_bytes(&tokenizer).map_err(|m| JsError::new(&m.to_string()))?; let start = Date::now(); |