summaryrefslogtreecommitdiff
path: root/candle-wasm-examples/bert/bertWorker.js
diff options
context:
space:
mode:
Diffstat (limited to 'candle-wasm-examples/bert/bertWorker.js')
-rw-r--r--candle-wasm-examples/bert/bertWorker.js77
1 files changed, 77 insertions, 0 deletions
diff --git a/candle-wasm-examples/bert/bertWorker.js b/candle-wasm-examples/bert/bertWorker.js
new file mode 100644
index 00000000..fd796c2b
--- /dev/null
+++ b/candle-wasm-examples/bert/bertWorker.js
@@ -0,0 +1,77 @@
// Load the Candle BERT wasm module.
+import init, { Model } from "./build/m.js";
+
/**
 * Fetch a binary resource as a Uint8Array, memoizing it in the Cache API
 * so repeated loads (e.g. multi-MB model weights) skip the network.
 * @param {string} url - URL of the resource to fetch.
 * @returns {Promise<Uint8Array>} the resource bytes.
 * @throws {Error} if the network response is not OK (non-2xx).
 */
async function fetchArrayBuffer(url) {
  const cacheName = "bert-candle-cache";
  const cache = await caches.open(cacheName);
  const cachedResponse = await cache.match(url);
  if (cachedResponse) {
    const data = await cachedResponse.arrayBuffer();
    return new Uint8Array(data);
  }
  const res = await fetch(url, { cache: "force-cache" });
  // Only cache successful responses: caching a 404/500 body would poison
  // every subsequent load until the cache is manually cleared.
  if (!res.ok) {
    throw new Error(`Failed to fetch ${url}: ${res.status} ${res.statusText}`);
  }
  // Clone before reading the body — a Response body can be consumed only
  // once. Await the put so a cache-write failure is not silently dropped.
  await cache.put(url, res.clone());
  return new Uint8Array(await res.arrayBuffer());
}
// Lazily-initialized, per-modelID singleton wrapper around the wasm Model,
// so several BERT variants can coexist in one worker without reloading.
class Bert {
  // Map of modelID -> Model instance.
  static instance = {};

  /**
   * Load (once) and return the wasm Model for `modelID`. Posts progress
   * messages to the main thread while loading.
   * @param {string} weightsURL - URL of the model weights.
   * @param {string} tokenizerURL - URL of the tokenizer JSON.
   * @param {string} configURL - URL of the model config JSON.
   * @param {string} modelID - cache key distinguishing model variants.
   * @returns {Promise<Model>} the ready-to-use model instance.
   */
  static async getInstance(weightsURL, tokenizerURL, configURL, modelID) {
    if (!this.instance[modelID]) {
      await init();

      self.postMessage({ status: "loading", message: "Loading Model" });
      // Download all three artifacts in parallel. The third buffer holds
      // the model *config* (the original misnamed it `mel_filtersArrayU8`,
      // a copy-paste leftover from the Whisper example).
      const [weightsArrayU8, tokenizerArrayU8, configArrayU8] =
        await Promise.all([
          fetchArrayBuffer(weightsURL),
          fetchArrayBuffer(tokenizerURL),
          fetchArrayBuffer(configURL),
        ]);

      this.instance[modelID] = new Model(
        weightsArrayU8,
        tokenizerArrayU8,
        configArrayU8
      );
    } else {
      self.postMessage({ status: "ready", message: "Model Already Loaded" });
    }
    return this.instance[modelID];
  }
}
+
// Worker entry point: for each request, ensure the model is loaded, compute
// sentence embeddings, and stream status updates back to the main thread.
self.addEventListener("message", async ({ data }) => {
  const { weightsURL, tokenizerURL, configURL, modelID, sentences } = data;
  // Default normalize to true only when the caller omitted it entirely.
  const normalize = data.normalize === undefined ? true : data.normalize;
  try {
    self.postMessage({ status: "ready", message: "Starting Bert Model" });
    const model = await Bert.getInstance(
      weightsURL,
      tokenizerURL,
      configURL,
      modelID
    );

    self.postMessage({
      status: "embedding",
      message: "Calculating Embeddings",
    });
    const output = model.get_embeddings({
      sentences,
      normalize_embeddings: normalize,
    });
    self.postMessage({
      status: "complete",
      message: "complete",
      output: output.data,
    });
  } catch (e) {
    // Report failures to the main thread instead of crashing the worker.
    self.postMessage({ error: e });
  }
});