summary | refs | log | tree | commit | diff
path: root/candle-wasm-examples
diff options
context:
space:
mode:
author: Radamés Ajna <radamajna@gmail.com> 2023-12-14 04:04:17 -0800
committer: GitHub <noreply@github.com> 2023-12-14 06:04:17 -0600
commit: 104e196d468d6c440c9f1fc504be37b2cbfb9722 (patch)
tree: deee30a79df97f3f3a2454462880b251a9f78af4 /candle-wasm-examples
parent: 5e33c85c8f7d2ae8c5fe8de557b69c036e4f080a (diff)
download: candle-104e196d468d6c440c9f1fc504be37b2cbfb9722.tar.gz
candle-104e196d468d6c440c9f1fc504be37b2cbfb9722.tar.bz2
candle-104e196d468d6c440c9f1fc504be37b2cbfb9722.zip
Phi 2 wasm (#1432)
* add phi 2.0 quantized model wasm * cols * spell * bug
Diffstat (limited to 'candle-wasm-examples')
-rw-r--r-- candle-wasm-examples/phi/index.html 90
-rw-r--r-- candle-wasm-examples/phi/phiWorker.js 17
-rw-r--r-- candle-wasm-examples/phi/src/bin/m.rs 21
3 files changed, 102 insertions, 26 deletions
diff --git a/candle-wasm-examples/phi/index.html b/candle-wasm-examples/phi/index.html
index 19c6a586..dbef698a 100644
--- a/candle-wasm-examples/phi/index.html
+++ b/candle-wasm-examples/phi/index.html
@@ -1,7 +1,7 @@
<html>
<head>
<meta content="text/html;charset=utf-8" http-equiv="Content-Type" />
- <title>Candle Phi 1.5 Rust/WASM</title>
+ <title>Candle Phi 1.5 / Phi 2.0 Rust/WASM</title>
</head>
<body></body>
</html>
@@ -39,7 +39,7 @@
import hljs from "https://cdn.skypack.dev/highlight.js";
// models base url
const MODELS = {
- phi_1_5_quantized: {
+ phi_1_5_q4k: {
base_url:
"https://huggingface.co/lmz/candle-quantized-phi/resolve/main/",
model: "model-q4k.gguf",
@@ -49,7 +49,7 @@
seq_len: 2048,
size: "800 MB",
},
- phi_1_5_quantized_2: {
+ phi_1_5_q80: {
base_url:
"https://huggingface.co/lmz/candle-quantized-phi/resolve/main/",
model: "model-q80.gguf",
@@ -59,7 +59,21 @@
seq_len: 2048,
size: "1.51 GB",
},
- puffin_phi_v2_quantized: {
+ phi_2_0_q4k: {
+ base_url:
+ "https://huggingface.co/radames/phi-2-quantized/resolve/main/",
+ model: [
+ "model-v2-q4k.gguf_aa.part",
+ "model-v2-q4k.gguf_ab.part",
+ "model-v2-q4k.gguf_ac.part",
+ ],
+ tokenizer: "tokenizer.json",
+ config: "config.json",
+ quantized: true,
+ seq_len: 2048,
+ size: "1.57GB",
+ },
+ puffin_phi_v2_q4k: {
base_url:
"https://huggingface.co/lmz/candle-quantized-phi/resolve/main/",
model: "model-puffin-phi-v2-q4k.gguf",
@@ -69,7 +83,7 @@
seq_len: 2048,
size: "798 MB",
},
- puffin_phi_v2_quantized_2: {
+ puffin_phi_v2_q80: {
base_url:
"https://huggingface.co/lmz/candle-quantized-phi/resolve/main/",
model: "model-puffin-phi-v2-q80.gguf",
@@ -106,8 +120,8 @@ Let’s think step by step.`,
},
{
title: "Question answering",
- prompt: `What is the capital of France?
-Answer:`,
+ prompt: `Instruct: What is the capital of France?
+Output:`,
},
{
title: "Chat mode",
@@ -148,7 +162,10 @@ Very polite review:`,
const getValue = (id) => document.querySelector(`#${id}`).value;
const modelID = getValue("model");
const model = MODELS[modelID];
- const weightsURL = model.base_url + model.model;
+ const weightsURL =
+ model.model instanceof Array
+ ? model.model.map((m) => model.base_url + m)
+ : model.base_url + model.model;
const tokenizerURL = model.base_url + model.tokenizer;
const configURL = model.base_url + model.config;
@@ -246,6 +263,13 @@ Very polite review:`,
option.innerText = `${id} (${model.size})`;
modelSelect.appendChild(option);
}
+ const query = new URLSearchParams(window.location.search);
+ const modelID = query.get("model");
+ if (modelID) {
+ modelSelect.value = modelID;
+ } else {
+ modelSelect.value = "phi_1_5_q4k";
+ }
for (const [i, { title, prompt }] of TEMPLATES.entries()) {
const div = document.createElement("div");
@@ -270,8 +294,18 @@ Very polite review:`,
prompt.value = template;
prompt.style.height = "auto";
prompt.style.height = prompt.scrollHeight + "px";
+ runBtn.disabled = false;
+ clearBtn.classList.remove("invisible");
});
modelSelect.addEventListener("change", (e) => {
+ const query = new URLSearchParams(window.location.search);
+ query.set("model", e.target.value);
+ window.history.replaceState(
+ {},
+ "",
+ `${window.location.pathname}?${query}`
+ );
+ window.parent.postMessage({ queryString: "?" + query }, "*");
const model = MODELS[e.target.value];
document.querySelector("#max-seq").max = model.seq_len;
document.querySelector("#max-seq").nextElementSibling.value = 200;
@@ -320,7 +354,7 @@ Very polite review:`,
<main class="grid grid-cols-1 gap-8 relative">
<span class="absolute text-5xl -ml-[1em]"> 🕯️ </span>
<div>
- <h1 class="text-5xl font-bold">Candle Phi 1.5</h1>
+ <h1 class="text-5xl font-bold">Candle Phi 1.5 / Phi 2.0</h1>
<h2 class="text-2xl font-bold">Rust/WASM Demo</h2>
<p class="max-w-lg">
The
@@ -330,10 +364,17 @@ Very polite review:`,
target="_blank"
>Phi-1.5</a
>
- model achieves state-of-the-art performance with only 1.3 billion
- parameters, compared to models with up to 10 billion. You can try the
- quantized version of the model here. Additional prompt examples are
- available in the
+ and
+ <a
+ href="https://huggingface.co/microsoft/phi-2"
+ class="link"
+ target="_blank"
+ >Phi-2</a
+ >
+ models achieve state-of-the-art performance with only 1.3 billion and
+ 2.7 billion parameters, compared to larger models with up to 13
+ billion parameters. Here you can try the quantized versions.
+ Additional prompt examples are available in the
<a
href="https://arxiv.org/pdf/2309.05463.pdf#page=8"
class="link"
@@ -350,7 +391,7 @@ Very polite review:`,
target="_blank"
>Puffin-Phi V2
</a>
- quantized version model, a fine-tuned version of Phi-1.5 on the
+ quantized version, a fine-tuned version of Phi-1.5 on the
<a
href="https://huggingface.co/datasets/LDJnr/Puffin"
class="link"
@@ -363,7 +404,7 @@ Very polite review:`,
<p class="text-xs italic max-w-lg">
<b>Note:</b>
When first run, the app will download and cache the model, which could
- take a few minutes. The models are <b>~800MB</b> or <b>~1.51GB</b> in
+ take a few minutes. The models are <b>~800MB</b> or <b>~1.57GB</b> in
size.
</p>
</div>
@@ -375,8 +416,13 @@ Very polite review:`,
></select>
</div>
<div>
- <h3 class="font-medium">Prompt Templates</h3>
- <form id="prompt-templates" class="flex flex-col gap-1 my-2"></form>
+ <details>
+ <summary class="font-medium cursor-pointer">Prompt Templates</summary>
+ <form
+ id="prompt-templates"
+ class="grid grid-cols-1 sm:grid-cols-2 gap-1 my-2"
+ ></form>
+ </details>
</div>
<form
id="form"
@@ -386,12 +432,12 @@ Very polite review:`,
<textarea
type="text"
id="prompt"
- class="font-light w-full px-3 py-2 mx-1 resize-none outline-none"
+ class="font-light text-lg w-full px-3 py-2 mx-1 resize-none outline-none"
oninput="this.style.height = 0;this.style.height = this.scrollHeight + 'px'"
placeholder="Add your prompt here..."
>
-Write a detailed analogy between mathematics and a lighthouse.
-Answer:</textarea
+Instruct: Write a detailed analogy between mathematics and a lighthouse.
+Output:</textarea
>
<button id="clear-btn">
<svg
@@ -517,9 +563,9 @@ Answer:</textarea
<div
id="output-counter"
hidden
- class="ml-auto font-semibold grid-rows-1 text-sm"
+ class="ml-auto font-semibold grid-rows-1"
></div>
- <p hidden id="output-generation" class="grid-rows-2"></p>
+ <p hidden id="output-generation" class="grid-rows-2 text-lg"></p>
<span id="output-status" class="m-auto font-light"
>No output yet</span
>
diff --git a/candle-wasm-examples/phi/phiWorker.js b/candle-wasm-examples/phi/phiWorker.js
index 5c030f1d..bb71b409 100644
--- a/candle-wasm-examples/phi/phiWorker.js
+++ b/candle-wasm-examples/phi/phiWorker.js
@@ -12,6 +12,20 @@ async function fetchArrayBuffer(url) {
cache.put(url, res.clone());
return new Uint8Array(await res.arrayBuffer());
}
+async function concatenateArrayBuffers(urls) {
+ const arrayBuffers = await Promise.all(urls.map(url => fetchArrayBuffer(url)));
+
+ let totalLength = arrayBuffers.reduce((acc, arrayBuffer) => acc + arrayBuffer.byteLength, 0);
+ let concatenatedBuffer = new Uint8Array(totalLength);
+
+ let offset = 0;
+ arrayBuffers.forEach(buffer => {
+ concatenatedBuffer.set(new Uint8Array(buffer), offset);
+ offset += buffer.byteLength;
+ });
+ return concatenatedBuffer;
+}
+
class Phi {
static instance = {};
@@ -27,10 +41,9 @@ class Phi {
await init();
self.postMessage({ status: "loading", message: "Loading Model" });
-
const [weightsArrayU8, tokenizerArrayU8, configArrayU8] =
await Promise.all([
- fetchArrayBuffer(weightsURL),
+ weightsURL instanceof Array ? concatenateArrayBuffers(weightsURL) : fetchArrayBuffer(weightsURL),
fetchArrayBuffer(tokenizerURL),
fetchArrayBuffer(configURL),
]);
diff --git a/candle-wasm-examples/phi/src/bin/m.rs b/candle-wasm-examples/phi/src/bin/m.rs
index c18e6c38..999f276d 100644
--- a/candle-wasm-examples/phi/src/bin/m.rs
+++ b/candle-wasm-examples/phi/src/bin/m.rs
@@ -5,6 +5,7 @@ use candle_transformers::models::mixformer::{Config, MixFormerSequentialForCausa
use candle_transformers::models::quantized_mixformer::MixFormerSequentialForCausalLM as QMixFormer;
use candle_wasm_example_phi::console_log;
use js_sys::Date;
+use serde::Deserialize;
use tokenizers::Tokenizer;
use wasm_bindgen::prelude::*;
@@ -23,6 +24,12 @@ pub struct Model {
repeat_last_n: usize,
}
+/// Minimal view of the model configuration JSON: only the `_name_or_path`
+/// entry is deserialized. The `Model` constructor parses the same config
+/// bytes into this struct and compares the value against `"microsoft/phi-2"`
+/// to choose between `QMixFormer::new_v2` and `QMixFormer::new`.
+#[derive(Debug, Clone, PartialEq, Deserialize)]
+
+pub struct ModelName {
+ // No serde rename attribute is present, so the field name itself is the
+ // JSON key looked up during deserialization; keep it spelled as is.
+ pub _name_or_path: String,
+}
+
#[wasm_bindgen]
impl Model {
#[wasm_bindgen(constructor)]
@@ -34,15 +41,25 @@ impl Model {
) -> Result<Model, JsError> {
console_error_panic_hook::set_once();
console_log!("loading model");
+ let name: ModelName = serde_json::from_slice(&config)?;
let config: Config = serde_json::from_slice(&config)?;
+
+ console_log!("config loaded {:?}", name);
let tokenizer =
Tokenizer::from_bytes(&tokenizer).map_err(|m| JsError::new(&m.to_string()))?;
let start = Date::now();
+ console_log!("weights len: {:?}", weights.len());
let model = if quantized {
let vb =
candle_transformers::quantized_var_builder::VarBuilder::from_gguf_buffer(&weights)?;
- let model = QMixFormer::new(&config, vb)?;
- SelectedModel::Quantized(model)
+ console_log!("weights loaded");
+ if name._name_or_path == "microsoft/phi-2" {
+ let model = QMixFormer::new_v2(&config, vb)?;
+ SelectedModel::Quantized(model)
+ } else {
+ let model = QMixFormer::new(&config, vb)?;
+ SelectedModel::Quantized(model)
+ }
} else {
let device = &Device::Cpu;
let vb = VarBuilder::from_buffered_safetensors(weights, DType::F32, device)?;