summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--Cargo.toml6
-rw-r--r--candle-wasm-examples/bert/Cargo.toml33
-rw-r--r--candle-wasm-examples/bert/README.md26
-rw-r--r--candle-wasm-examples/bert/bertWorker.js77
-rw-r--r--candle-wasm-examples/bert/build-lib.sh2
-rw-r--r--candle-wasm-examples/bert/lib-example.html368
-rw-r--r--candle-wasm-examples/bert/src/bin/m.rs92
-rw-r--r--candle-wasm-examples/bert/src/lib.rs20
-rw-r--r--candle-wasm-examples/bert/utils.js99
9 files changed, 719 insertions, 4 deletions
diff --git a/Cargo.toml b/Cargo.toml
index 3a5763a1..6cbbf00f 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -11,11 +11,9 @@ members = [
"candle-wasm-examples/segment-anything",
"candle-wasm-examples/whisper",
"candle-wasm-examples/yolo",
+ "candle-wasm-examples/bert",
]
-exclude = [
- "candle-flash-attn",
- "candle-kernels",
-]
+exclude = ["candle-flash-attn", "candle-kernels"]
resolver = "2"
[workspace.package]
diff --git a/candle-wasm-examples/bert/Cargo.toml b/candle-wasm-examples/bert/Cargo.toml
new file mode 100644
index 00000000..81a043de
--- /dev/null
+++ b/candle-wasm-examples/bert/Cargo.toml
@@ -0,0 +1,33 @@
+[package]
+name = "candle-wasm-example-bert"
+version.workspace = true
+edition.workspace = true
+description.workspace = true
+repository.workspace = true
+keywords.workspace = true
+categories.workspace = true
+license.workspace = true
+
+[dependencies]
+candle = { path = "../../candle-core", version = "0.2.2", package = "candle-core" }
+candle-nn = { path = "../../candle-nn", version = "0.2.2" }
+candle-transformers = { path = "../../candle-transformers", version = "0.2.2" }
+num-traits = { workspace = true }
+tokenizers = { workspace = true, features = ["unstable_wasm"] }
+
+# App crates.
+anyhow = { workspace = true }
+byteorder = { workspace = true }
+log = { workspace = true }
+rand = { workspace = true }
+serde = { workspace = true }
+serde_json = { workspace = true }
+safetensors = { workspace = true }
+
+# Wasm specific crates.
+console_error_panic_hook = "0.1.7"
+getrandom = { version = "0.2", features = ["js"] }
+gloo = "0.8"
+js-sys = "0.3.64"
+wasm-bindgen = "0.2.87"
+serde-wasm-bindgen = "0.6.0"
diff --git a/candle-wasm-examples/bert/README.md b/candle-wasm-examples/bert/README.md
new file mode 100644
index 00000000..c34d33cc
--- /dev/null
+++ b/candle-wasm-examples/bert/README.md
@@ -0,0 +1,26 @@
+## Running BERT with Candle and WASM
+
+Here, we provide an example of how to run Bert using a Candle-compiled WASM binary and runtime.
+
+### Vanilla JS and WebWorkers
+
+To build and test the UI made in Vanilla JS and WebWorkers, first we need to build the WASM library:
+
+```bash
+sh build-lib.sh
+```
+
+This will bundle the library under `./build` and we can import it inside our WebWorker like a normal JS module:
+
+```js
+import init, { Model } from "./build/m.js";
+```
+
+The full example can be found under `./lib-example.html`. All needed assets are fetched from the web, so no need to download anything.
+Finally, you can preview the example by running a local HTTP server. For example:
+
+```bash
+python -m http.server
+```
+
+Then open `http://localhost:8000/lib-example.html` in your browser.
diff --git a/candle-wasm-examples/bert/bertWorker.js b/candle-wasm-examples/bert/bertWorker.js
new file mode 100644
index 00000000..fd796c2b
--- /dev/null
+++ b/candle-wasm-examples/bert/bertWorker.js
@@ -0,0 +1,77 @@
+// Load the Candle BERT wasm module.
+import init, { Model } from "./build/m.js";
+
// Fetch a binary asset as a Uint8Array, memoized in the Cache Storage API so
// large model files are only downloaded once.
async function fetchArrayBuffer(url) {
  const cacheName = "bert-candle-cache";
  const cache = await caches.open(cacheName);
  const cachedResponse = await cache.match(url);
  if (cachedResponse) {
    const data = await cachedResponse.arrayBuffer();
    return new Uint8Array(data);
  }
  const res = await fetch(url, { cache: "force-cache" });
  // Only cache successful responses — caching a 404/500 would poison the
  // cache for every later load of this model.
  if (!res.ok) {
    throw new Error(`Failed to fetch ${url}: ${res.status}`);
  }
  // Await the write so a partially-stored entry can't be read back later.
  await cache.put(url, res.clone());
  return new Uint8Array(await res.arrayBuffer());
}
// Lazy, per-model singleton holder for the wasm `Model`.
class Bert {
  // Cache of loaded Model instances, keyed by modelID.
  static instance = {};

  // Load (or reuse) the wasm model for `modelID`, fetching its assets in
  // parallel on first use.
  static async getInstance(weightsURL, tokenizerURL, configURL, modelID) {
    if (!this.instance[modelID]) {
      await init();

      self.postMessage({ status: "loading", message: "Loading Model" });
      // Fixed: third buffer was misnamed `mel_filtersArrayU8` (leftover from
      // the whisper example); it actually holds the BERT config.json bytes.
      const [weightsArrayU8, tokenizerArrayU8, configArrayU8] =
        await Promise.all([
          fetchArrayBuffer(weightsURL),
          fetchArrayBuffer(tokenizerURL),
          fetchArrayBuffer(configURL),
        ]);

      this.instance[modelID] = new Model(
        weightsArrayU8,
        tokenizerArrayU8,
        configArrayU8
      );
    } else {
      self.postMessage({ status: "ready", message: "Model Already Loaded" });
    }
    return this.instance[modelID];
  }
}
+
// Worker entry point: load (or reuse) the requested model, compute the
// embeddings, and stream status updates back to the page.
self.addEventListener("message", async (event) => {
  const {
    weightsURL,
    tokenizerURL,
    configURL,
    modelID,
    sentences,
    normalize = true,
  } = event.data;
  try {
    self.postMessage({ status: "ready", message: "Starting Bert Model" });
    const model = await Bert.getInstance(
      weightsURL,
      tokenizerURL,
      configURL,
      modelID
    );
    self.postMessage({
      status: "embedding",
      message: "Calculating Embeddings",
    });
    const output = model.get_embeddings({
      sentences,
      normalize_embeddings: normalize,
    });

    self.postMessage({
      status: "complete",
      message: "complete",
      output: output.data,
    });
  } catch (e) {
    // Post a plain string: utils.js wraps this value in `new Error(...)`, and
    // posting the raw Error object stringifies as "Error: …" (or fails to
    // structured-clone in older runtimes).
    self.postMessage({ error: e.message ?? String(e) });
  }
});
diff --git a/candle-wasm-examples/bert/build-lib.sh b/candle-wasm-examples/bert/build-lib.sh
new file mode 100644
index 00000000..b0ebb182
--- /dev/null
+++ b/candle-wasm-examples/bert/build-lib.sh
@@ -0,0 +1,2 @@
+cargo build --target wasm32-unknown-unknown --release
+wasm-bindgen ../../target/wasm32-unknown-unknown/release/m.wasm --out-dir build --target web
diff --git a/candle-wasm-examples/bert/lib-example.html b/candle-wasm-examples/bert/lib-example.html
new file mode 100644
index 00000000..d10ea1db
--- /dev/null
+++ b/candle-wasm-examples/bert/lib-example.html
@@ -0,0 +1,368 @@
+<!DOCTYPE html>
+<html>
+  <head>
+    <meta charset="UTF-8" />
+    <meta content="text/html;charset=utf-8" http-equiv="Content-Type" />
+    <title>Candle Bert</title>
+    <meta name="viewport" content="width=device-width, initial-scale=1.0" />
+ <style>
+ @import url("https://fonts.googleapis.com/css2?family=Source+Code+Pro:wght@200;300;400&family=Source+Sans+3:wght@100;200;300;400;500;600;700;800;900&display=swap");
+ html,
+ body {
+ font-family: "Source Sans 3", sans-serif;
+ }
+ </style>
+ <script src="https://cdn.tailwindcss.com"></script>
+ <script type="module" src="./code.js"></script>
+ <script type="module">
+ import { hcl } from "https://cdn.skypack.dev/d3-color@3";
+ import { interpolateReds } from "https://cdn.skypack.dev/d3-scale-chromatic@3";
+ import { scaleLinear } from "https://cdn.skypack.dev/d3-scale@4";
+ import {
+ getModelInfo,
+ getEmbeddings,
+ getWikiText,
+ cosineSimilarity,
+ } from "./utils.js";
+
+ const bertWorker = new Worker("./bertWorker.js", {
+ type: "module",
+ });
+
+ const inputContainerEL = document.querySelector("#input-container");
+ const textAreaEl = document.querySelector("#input-area");
+ const outputAreaEl = document.querySelector("#output-area");
+ const formEl = document.querySelector("#form");
+ const searchInputEl = document.querySelector("#search-input");
+ const formWikiEl = document.querySelector("#form-wiki");
+ const searchWikiEl = document.querySelector("#search-wiki");
+ const outputStatusEl = document.querySelector("#output-status");
+ const modelSelectEl = document.querySelector("#model");
+
+ const sentencesRegex =
+ /(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<![A-Z]\.)(?<=\.|\?)\s/gm;
+
+ let sentenceEmbeddings = [];
+ let currInputText = "";
+ let isCalculating = false;
+
+ function toggleTextArea(state) {
+ if (state) {
+ textAreaEl.hidden = false;
+ textAreaEl.focus();
+ } else {
+ textAreaEl.hidden = true;
+ }
+ }
+ inputContainerEL.addEventListener("focus", (e) => {
+ toggleTextArea(true);
+ });
+ textAreaEl.addEventListener("blur", (e) => {
+ toggleTextArea(false);
+ });
+ textAreaEl.addEventListener("focusout", (e) => {
+ toggleTextArea(false);
+ if (currInputText === textAreaEl.value || isCalculating) return;
+ populateOutputArea(textAreaEl.value);
+ calculateEmbeddings(textAreaEl.value);
+ });
+
+ modelSelectEl.addEventListener("change", (e) => {
+ if (currInputText === "" || isCalculating) return;
+ populateOutputArea(textAreaEl.value);
+ calculateEmbeddings(textAreaEl.value);
+ });
+
+ function populateOutputArea(text) {
+ currInputText = text;
+ const sentences = text.split(sentencesRegex);
+
+ outputAreaEl.innerHTML = "";
+ for (const [id, sentence] of sentences.entries()) {
+ const sentenceEl = document.createElement("span");
+ sentenceEl.id = `sentence-${id}`;
+ sentenceEl.innerText = sentence + " ";
+ outputAreaEl.appendChild(sentenceEl);
+ }
+ }
+      // Semantic search: embed the query with the selected model, rank all
+      // previously-embedded sentences by cosine similarity, and highlight
+      // the top 10 matches on a red color ramp.
+      formEl.addEventListener("submit", async (e) => {
+        e.preventDefault();
+        if (isCalculating || currInputText === "") return;
+        toggleInputs(true);
+        const modelID = modelSelectEl.value;
+        const { modelURL, tokenizerURL, configURL, search_prefix } =
+          getModelInfo(modelID);
+
+        // NOTE(review): `text` is unused — the prefixed `query` below is
+        // what actually gets embedded.
+        const text = searchInputEl.value;
+        // e5-style models expect a "query: " prefix on search queries.
+        const query = search_prefix + searchInputEl.value;
+        outputStatusEl.classList.remove("invisible");
+        outputStatusEl.innerText = "Calculating embeddings for query...";
+        isCalculating = true;
+        const out = await getEmbeddings(
+          bertWorker,
+          modelURL,
+          tokenizerURL,
+          configURL,
+          modelID,
+          [query]
+        );
+        outputStatusEl.classList.add("invisible");
+        const queryEmbeddings = out.output[0];
+        // calculate cosine similarity with all sentences given the query
+        const distances = sentenceEmbeddings
+          .map((embedding, id) => ({
+            id,
+            similarity: cosineSimilarity(queryEmbeddings, embedding),
+          }))
+          .sort((a, b) => b.similarity - a.similarity)
+          // getting top 10 most similar sentences
+          .slice(0, 10);
+
+        // Map the [lowest, highest] similarity of the top matches onto [0, 1]
+        // and color through d3's interpolateReds ramp.
+        const colorScale = scaleLinear()
+          .domain([
+            distances[distances.length - 1].similarity,
+            distances[0].similarity,
+          ])
+          .range([0, 1])
+          .interpolate(() => interpolateReds);
+        // Clear highlighting left over from a previous search.
+        outputAreaEl.querySelectorAll("span").forEach((el) => {
+          el.style.color = "unset";
+          el.style.backgroundColor = "unset";
+        });
+        distances.forEach((d) => {
+          const el = outputAreaEl.querySelector(`#sentence-${d.id}`);
+          const color = colorScale(d.similarity);
+          // Pick a readable font color against the highlight background.
+          const fontColor = hcl(color).l < 70 ? "white" : "black";
+          el.style.color = fontColor;
+          el.style.backgroundColor = color;
+        });
+
+        // Scroll the best match into view.
+        outputAreaEl
+          .querySelector(`#sentence-${distances[0].id}`)
+          .scrollIntoView({
+            behavior: "smooth",
+            block: "center",
+            inline: "nearest",
+          });
+
+        isCalculating = false;
+        toggleInputs(false);
+      });
+      // Embed every sentence of `text` (one worker round-trip per sentence)
+      // and cache the resulting vectors in `sentenceEmbeddings` for search.
+      async function calculateEmbeddings(text) {
+        isCalculating = true;
+        toggleInputs(true);
+        const modelID = modelSelectEl.value;
+        const { modelURL, tokenizerURL, configURL, document_prefix } =
+          getModelInfo(modelID);
+
+        const sentences = text.split(sentencesRegex);
+        const allEmbeddings = [];
+        outputStatusEl.classList.remove("invisible");
+        for (const [id, sentence] of sentences.entries()) {
+          // e5-style models expect a "passage: " prefix on documents.
+          const query = document_prefix + sentence;
+          outputStatusEl.innerText = `Calculating embeddings: sentence ${
+            id + 1
+          } of ${sentences.length}`;
+          const embeddings = await getEmbeddings(
+            bertWorker,
+            modelURL,
+            tokenizerURL,
+            configURL,
+            modelID,
+            [query],
+            updateStatus
+          );
+          allEmbeddings.push(embeddings);
+        }
+        outputStatusEl.classList.add("invisible");
+        sentenceEmbeddings = allEmbeddings.map((e) => e.output[0]);
+        isCalculating = false;
+        toggleInputs(false);
+      }
+
+ function updateStatus(data) {
+ if ("status" in data) {
+ if (data.status === "loading") {
+ outputStatusEl.innerText = data.message;
+ outputStatusEl.classList.remove("invisible");
+ }
+ }
+ }
+ function toggleInputs(state) {
+ const interactive = document.querySelectorAll(".interactive");
+ interactive.forEach((el) => {
+ if (state) {
+ el.disabled = true;
+ } else {
+ el.disabled = false;
+ }
+ });
+ }
+
+ searchWikiEl.addEventListener("input", () => {
+ searchWikiEl.setCustomValidity("");
+ });
+
+      // Load a Wikipedia article (typed title, or one of the example
+      // buttons) into the input area and kick off sentence embedding.
+      formWikiEl.addEventListener("submit", async (e) => {
+        e.preventDefault();
+        if ("example" in e.submitter.dataset) {
+          searchWikiEl.value = e.submitter.innerText;
+        }
+        const text = searchWikiEl.value;
+
+        if (isCalculating || text === "") return;
+        try {
+          const wikiText = await getWikiText(text);
+          searchWikiEl.setCustomValidity("");
+          textAreaEl.innerHTML = wikiText;
+          populateOutputArea(wikiText);
+          // Deliberately not awaited: embedding runs while the form resets.
+          calculateEmbeddings(wikiText);
+          searchWikiEl.value = "";
+        } catch {
+          searchWikiEl.setCustomValidity("Invalid Wikipedia article name");
+          searchWikiEl.reportValidity();
+        }
+      });
+ </script>
+ </head>
+ <body class="container max-w-4xl mx-auto p-4">
+ <main class="grid grid-cols-1 gap-5 relative">
+ <span class="absolute text-5xl -ml-[1em]"> 🕯️ </span>
+ <div>
+ <h1 class="text-5xl font-bold">Candle BERT</h1>
+ <h2 class="text-2xl font-bold">Rust/WASM Demo</h2>
+ <p class="max-w-lg">
+ Running sentence embeddings and similarity search in the browser using
+ the Bert Model written with
+ <a
+ href="https://github.com/huggingface/candle/"
+ target="_blank"
+ class="underline hover:text-blue-500 hover:no-underline"
+ >Candle
+ </a>
+            and compiled to Wasm. Embeddings models are from
+ <a
+ href="https://huggingface.co/sentence-transformers/"
+ target="_blank"
+ class="underline hover:text-blue-500 hover:no-underline"
+ >
+ Sentence Transformers
+ </a>
+ and
+ <a
+ href="https://huggingface.co/intfloat/"
+ target="_blank"
+ class="underline hover:text-blue-500 hover:no-underline"
+ >
+ Liang Wang - e5 Models
+ </a>
+ </p>
+ </div>
+
+ <div>
+ <label for="model" class="font-medium block">Models Options: </label>
+ <select
+ id="model"
+ class="border-2 border-gray-500 rounded-md font-light interactive disabled:cursor-not-allowed w-full max-w-max"
+ >
+ <option value="intfloat_e5_small_v2" selected>
+ intfloat/e5-small-v2 (133 MB)
+ </option>
+ <option value="intfloat_e5_base_v2">
+ intfloat/e5-base-v2 (438 MB)
+ </option>
+ <option value="intfloat_multilingual_e5_small">
+ intfloat/multilingual-e5-small (471 MB)
+ </option>
+ <option value="sentence_transformers_all_MiniLM_L6_v2">
+ sentence-transformers/all-MiniLM-L6-v2 (90.9 MB)
+ </option>
+ <option value="sentence_transformers_all_MiniLM_L12_v2">
+ sentence-transformers/all-MiniLM-L12-v2 (133 MB)
+ </option>
+ </select>
+ </div>
+ <div>
+ <h3 class="font-medium">Examples:</h3>
+ <form
+ id="form-wiki"
+ class="flex text-xs rounded-md justify-between w-min gap-3"
+ >
+ <input type="submit" hidden />
+
+ <button data-example class="disabled:cursor-not-allowed interactive">
+ Pizza
+ </button>
+ <button data-example class="disabled:cursor-not-allowed interactive">
+ Paris
+ </button>
+ <button data-example class="disabled:cursor-not-allowed interactive">
+ Physics
+ </button>
+ <input
+ type="text"
+ id="search-wiki"
+ title="Search Wikipedia article by title"
+ class="font-light py-0 mx-1 resize-none outline-none w-32 disabled:cursor-not-allowed interactive"
+ placeholder="Load Wikipedia article..."
+ />
+ <button
+ title="Search Wikipedia article and load into input"
+ class="bg-gray-700 hover:bg-gray-800 text-white font-normal px-2 py-1 rounded disabled:bg-gray-300 disabled:cursor-not-allowed interactive"
+ >
+ Load
+ </button>
+ </form>
+ </div>
+ <form
+ id="form"
+ class="flex text-normal px-1 py-1 border border-gray-700 rounded-md items-center"
+ >
+ <input type="submit" hidden />
+ <input
+ type="text"
+ id="search-input"
+ class="font-light w-full px-3 py-2 mx-1 resize-none outline-none interactive disabled:cursor-not-allowed"
+ placeholder="Search query here..."
+ />
+ <button
+ class="bg-gray-700 hover:bg-gray-800 text-white font-normal py-2 w-16 rounded disabled:bg-gray-300 disabled:cursor-not-allowed interactive"
+ >
+ Search
+ </button>
+ </form>
+ <div>
+ <h3 class="font-medium">Input text:</h3>
+ <div class="flex justify-between items-center">
+ <div class="rounded-md inline text-xs">
+ <span id="output-status" class="m-auto font-light invisible"
+ >C</span
+ >
+ </div>
+ </div>
+ <div
+ id="input-container"
+ tabindex="0"
+ class="min-h-[250px] bg-slate-100 text-gray-500 rounded-md p-4 flex flex-col gap-2 relative"
+ >
+ <textarea
+ id="input-area"
+ hidden
+ value=""
+ placeholder="Input text to perform semantic similarity search..."
+ class="flex-1 resize-none outline-none left-0 right-0 top-0 bottom-0 m-4 absolute interactive disabled:invisible"
+ ></textarea>
+ <p id="output-area" class="grid-rows-2">
+ Input text to perform semantic similarity search...
+ </p>
+ </div>
+ </div>
+ </main>
+ </body>
+</html>
diff --git a/candle-wasm-examples/bert/src/bin/m.rs b/candle-wasm-examples/bert/src/bin/m.rs
new file mode 100644
index 00000000..f5521abd
--- /dev/null
+++ b/candle-wasm-examples/bert/src/bin/m.rs
@@ -0,0 +1,92 @@
+use candle::{DType, Device, Tensor};
+use candle_nn::VarBuilder;
+use candle_transformers::models::bert::{BertModel, Config};
+use candle_wasm_example_bert::console_log;
+use tokenizers::{PaddingParams, Tokenizer};
+use wasm_bindgen::prelude::*;
+
+#[wasm_bindgen]
+pub struct Model {
+    // BERT encoder built from the provided safetensors weights.
+    bert: BertModel,
+    // Tokenizer built from the provided tokenizer.json bytes.
+    tokenizer: Tokenizer,
+}
+
+#[wasm_bindgen]
+impl Model {
+    /// Construct a `Model` from raw byte buffers fetched on the JS side:
+    /// safetensors weights, tokenizer.json, and BERT config.json.
+    #[wasm_bindgen(constructor)]
+    pub fn load(weights: Vec<u8>, tokenizer: Vec<u8>, config: Vec<u8>) -> Result<Model, JsError> {
+        console_error_panic_hook::set_once();
+        console_log!("loading model");
+        let device = &Device::Cpu;
+        let weights = safetensors::tensor::SafeTensors::deserialize(&weights)?;
+        // NOTE(review): weights are materialized as F64 here; BERT
+        // safetensors checkpoints are typically F32 — confirm this dtype
+        // (and the memory cost) is intentional.
+        let vb = VarBuilder::from_safetensors(vec![weights], DType::F64, device);
+        let config: Config = serde_json::from_slice(&config)?;
+        let tokenizer =
+            Tokenizer::from_bytes(&tokenizer).map_err(|m| JsError::new(&m.to_string()))?;
+        let bert = BertModel::load(vb, &config)?;
+
+        Ok(Self { bert, tokenizer })
+    }
+
+    /// Embed a batch of sentences.
+    ///
+    /// `input` is a JS object matching `Params`:
+    /// `{ sentences: string[], normalize_embeddings: bool }`.
+    /// Returns `{ data: number[][] }` with one vector per sentence
+    /// (mean-pooled over tokens, optionally L2-normalized).
+    pub fn get_embeddings(&mut self, input: JsValue) -> Result<JsValue, JsError> {
+        let input: Params =
+            serde_wasm_bindgen::from_value(input).map_err(|m| JsError::new(&m.to_string()))?;
+        let sentences = input.sentences;
+        let normalize_embeddings = input.normalize_embeddings;
+
+        let device = &Device::Cpu;
+        // Pad every sequence to the longest in the batch so they can be
+        // stacked into one rectangular tensor below.
+        if let Some(pp) = self.tokenizer.get_padding_mut() {
+            pp.strategy = tokenizers::PaddingStrategy::BatchLongest
+        } else {
+            let pp = PaddingParams {
+                strategy: tokenizers::PaddingStrategy::BatchLongest,
+                ..Default::default()
+            };
+            self.tokenizer.with_padding(Some(pp));
+        }
+        let tokens = self
+            .tokenizer
+            .encode_batch(sentences.to_vec(), true)
+            .map_err(|m| JsError::new(&m.to_string()))?;
+
+        let token_ids: Vec<Tensor> = tokens
+            .iter()
+            .map(|tokens| {
+                let tokens = tokens.get_ids().to_vec();
+                Tensor::new(tokens.as_slice(), device)
+            })
+            .collect::<Result<Vec<_>, _>>()?;
+
+        let token_ids = Tensor::stack(&token_ids, 0)?;
+        // Single-segment input: all token-type ids are zero.
+        let token_type_ids = token_ids.zeros_like()?;
+        console_log!("running inference on batch {:?}", token_ids.shape());
+        let embeddings = self.bert.forward(&token_ids, &token_type_ids)?;
+        console_log!("generated embeddings {:?}", embeddings.shape());
+        // Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
+        let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?;
+        let embeddings = (embeddings.sum(1)? / (n_tokens as f64))?;
+        let embeddings = if normalize_embeddings {
+            // L2-normalize each row so dot products equal cosine similarity.
+            embeddings.broadcast_div(&embeddings.sqr()?.sum_keepdim(1)?.sqrt()?)?
+        } else {
+            embeddings
+        };
+        let embeddings_data = embeddings.to_vec2()?;
+        Ok(serde_wasm_bindgen::to_value(&Embeddings {
+            data: embeddings_data,
+        })?)
+    }
+}
+
+/// Serialized result handed back to JS: one embedding vector per sentence.
+#[derive(serde::Serialize, serde::Deserialize)]
+struct Embeddings {
+    data: Vec<Vec<f64>>,
+}
+
+/// Input payload from JS for `Model::get_embeddings`.
+#[derive(serde::Serialize, serde::Deserialize)]
+pub struct Params {
+    sentences: Vec<String>,
+    // When true, L2-normalize each embedding (cosine-similarity ready).
+    normalize_embeddings: bool,
+}
+/// Required entry point for the wasm binary target; all real work happens
+/// through the wasm-bindgen-exported `Model`.
+fn main() {
+    console_error_panic_hook::set_once();
+}
diff --git a/candle-wasm-examples/bert/src/lib.rs b/candle-wasm-examples/bert/src/lib.rs
new file mode 100644
index 00000000..1e3657be
--- /dev/null
+++ b/candle-wasm-examples/bert/src/lib.rs
@@ -0,0 +1,20 @@
+use candle_transformers::models::bert;
+use wasm_bindgen::prelude::*;
+
+// Re-export the types the binary target (src/bin/m.rs) needs.
+pub use bert::{BertModel, Config, DTYPE};
+pub use tokenizers::{PaddingParams, Tokenizer};
+
+#[wasm_bindgen]
+extern "C" {
+    // Use `js_namespace` here to bind `console.log(..)` instead of just
+    // `log(..)`
+    #[wasm_bindgen(js_namespace = console)]
+    pub fn log(s: &str);
+}
+
+/// Log to the browser console with `format!`-style arguments.
+#[macro_export]
+macro_rules! console_log {
+    // Note that this is using the `log` function imported above during
+    // `bare_bones`
+    ($($t:tt)*) => ($crate::log(&format_args!($($t)*).to_string()))
+}
diff --git a/candle-wasm-examples/bert/utils.js b/candle-wasm-examples/bert/utils.js
new file mode 100644
index 00000000..9d8bd7bd
--- /dev/null
+++ b/candle-wasm-examples/bert/utils.js
@@ -0,0 +1,99 @@
/**
 * Request sentence embeddings from the Bert web worker.
 *
 * Posts the model/asset URLs and sentences to `worker`, then resolves with
 * the worker's `complete` message (`{ status, message, output }`) or rejects
 * with its `error` message. Intermediate messages are forwarded to
 * `updateStatus` when provided.
 */
export async function getEmbeddings(
  worker,
  weightsURL,
  tokenizerURL,
  configURL,
  modelID,
  sentences,
  updateStatus = null
) {
  return new Promise((resolve, reject) => {
    worker.postMessage({
      weightsURL,
      tokenizerURL,
      configURL,
      modelID,
      sentences,
    });
    function messageHandler(event) {
      if ("error" in event.data) {
        worker.removeEventListener("message", messageHandler);
        reject(new Error(event.data.error));
        // Fixed: return after settling so the terminal payload is not also
        // forwarded to updateStatus below.
        return;
      }
      if (event.data.status === "complete") {
        worker.removeEventListener("message", messageHandler);
        resolve(event.data);
        return;
      }
      if (updateStatus) updateStatus(event.data);
    }
    worker.addEventListener("message", messageHandler);
  });
}
+
// Model registry: Hugging Face asset base URL plus the query/document
// prefixes e5-style models were trained with (empty for plain
// sentence-transformers models).
const MODELS = {
  intfloat_e5_small_v2: {
    base_url: "https://huggingface.co/intfloat/e5-small-v2/resolve/main/",
    search_prefix: "query: ",
    document_prefix: "passage: ",
  },
  intfloat_e5_base_v2: {
    base_url: "https://huggingface.co/intfloat/e5-base-v2/resolve/main/",
    search_prefix: "query: ",
    // Fixed: was "passage:" (missing trailing space), inconsistent with the
    // other e5 entries and the "passage: " prefix e5 expects.
    document_prefix: "passage: ",
  },
  intfloat_multilingual_e5_small: {
    base_url:
      "https://huggingface.co/intfloat/multilingual-e5-small/resolve/main/",
    search_prefix: "query: ",
    document_prefix: "passage: ",
  },
  sentence_transformers_all_MiniLM_L6_v2: {
    base_url:
      "https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/resolve/refs%2Fpr%2F21/",
    search_prefix: "",
    document_prefix: "",
  },
  sentence_transformers_all_MiniLM_L12_v2: {
    base_url:
      "https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2/resolve/refs%2Fpr%2F4/",
    search_prefix: "",
    document_prefix: "",
  },
};
// Resolve the asset URLs and input prefixes for a MODELS registry entry.
export function getModelInfo(id) {
  const { base_url, search_prefix, document_prefix } = MODELS[id];
  return {
    modelURL: `${base_url}model.safetensors`,
    configURL: `${base_url}config.json`,
    tokenizerURL: `${base_url}tokenizer.json`,
    search_prefix,
    document_prefix,
  };
}
+
// cos(θ) = (v1 · v2) / (|v1| |v2|); assumes equal-length numeric vectors.
export function cosineSimilarity(vec1, vec2) {
  let dot = 0;
  let sqA = 0;
  let sqB = 0;
  for (let i = 0; i < vec1.length; i += 1) {
    dot += vec1[i] * vec2[i];
    sqA += vec1[i] * vec1[i];
    sqB += vec2[i] * vec2[i];
  }
  return dot / (Math.sqrt(sqA) * Math.sqrt(sqB));
}
/**
 * Fetch the plain-text extract of a Wikipedia article by title.
 *
 * Rejects on network failure or when no article matches, so callers can
 * surface a validation error (the wiki form relies on the rejection).
 * Fixes: the previous `.catch` logged and swallowed the error, resolving
 * with `undefined`; the title was not URL-encoded, breaking titles with
 * spaces, `&`, etc.; the local `URL` shadowed the global constructor.
 */
export async function getWikiText(article) {
  // thanks to wikipedia for the API
  const url = `https://en.wikipedia.org/w/api.php?action=query&prop=extracts&exlimit=1&titles=${encodeURIComponent(
    article
  )}&explaintext=1&exsectionformat=plain&format=json&origin=*`;
  const response = await fetch(url, {
    method: "GET",
    headers: {
      Accept: "application/json",
    },
  });
  if (!response.ok) {
    throw new Error(`Wikipedia request failed: ${response.status}`);
  }
  const data = await response.json();
  const pages = data.query.pages;
  const pageId = Object.keys(pages)[0];
  const extract = pages[pageId].extract;
  if (extract === undefined || extract === "") {
    throw new Error("No article found");
  }
  return extract;
}