Diffstat (limited to 'candle-wasm-examples/t5/T5ModelEncoderWorker.js')
-rw-r--r--  candle-wasm-examples/t5/T5ModelEncoderWorker.js  83
1 file changed, 83 insertions, 0 deletions
diff --git a/candle-wasm-examples/t5/T5ModelEncoderWorker.js b/candle-wasm-examples/t5/T5ModelEncoderWorker.js
new file mode 100644
index 00000000..a83b0ee0
--- /dev/null
+++ b/candle-wasm-examples/t5/T5ModelEncoderWorker.js
@@ -0,0 +1,83 @@
+// load the Candle T5 encoder wasm module
+let init, ModelEncoder;
+
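+// fetch a binary asset as a Uint8Array, serving it from the Cache Storage
+// API when possible so model files are only downloaded once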
+async function fetchArrayBuffer(url) {
+  const cacheName = "t5-candle-cache";
+  const cache = await caches.open(cacheName);
+  const cachedResponse = await cache.match(url);
+  if (cachedResponse) {
+    const data = await cachedResponse.arrayBuffer();
+    return new Uint8Array(data);
+  }
+  const res = await fetch(url, { cache: "force-cache" });
+  // cache only successful responses; clone because a Response body is single-use
+  if (res.ok) {
+    cache.put(url, res.clone());
+  }
+  return new Uint8Array(await res.arrayBuffer());
+}
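+
+// lazily creates and memoizes one ModelEncoder per modelID, so repeated
+// requests for the same model reuse the already-initialized instance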
+class Encoder {
+  static instance = {};
+
+  static async getInstance(weightsURL, tokenizerURL, configURL, modelID) {
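+    // pick the quantized or full-precision wasm build based on the model ID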
+    if (modelID.includes("quantized")) {
+      ({ default: init, ModelEncoder } = await import(
+        "./build/m-quantized.js"
+      ));
+    } else {
+      ({ default: init, ModelEncoder } = await import("./build/m.js"));
+    }
+    if (!this.instance[modelID]) {
+      await init();
+
+      self.postMessage({ status: "loading", message: "Loading Model" });
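+      // download weights, tokenizer, and config in parallel (cache-backed)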
+      const [weightsArrayU8, tokenizerArrayU8, configArrayU8] =
+        await Promise.all([
+          fetchArrayBuffer(weightsURL),
+          fetchArrayBuffer(tokenizerURL),
+          fetchArrayBuffer(configURL),
+        ]);
+
+      this.instance[modelID] = new ModelEncoder(
+        weightsArrayU8,
+        tokenizerArrayU8,
+        configArrayU8
+      );
+    } else {
+      self.postMessage({ status: "ready", message: "Model Already Loaded" });
+    }
+    return this.instance[modelID];
+  }
+}
+
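+// worker entry point: expects a message carrying the model asset URLs, a
+// model ID, the sentences to embed, and an optional normalize_embeddings flag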
+self.addEventListener("message", async (event) => {
+  const {
+    weightsURL,
+    tokenizerURL,
+    configURL,
+    modelID,
+    sentences,
+    normalize_embeddings,
+  } = event.data;
+  try {
+    self.postMessage({ status: "ready", message: "Starting T5 Encoder" });
+    const model = await Encoder.getInstance(
+      weightsURL,
+      tokenizerURL,
+      configURL,
+      modelID
+    );
+    self.postMessage({
+      status: "encoding",
+      message: "Encoding Sentences",
+    });
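+    // despite its name, decode() runs the T5 encoder and returns the
+    // sentence embeddings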
+    const output = model.decode({
+      sentences: sentences,
+      // ?? keeps an explicit `false`; `|| true` would always coerce to true
+      normalize_embeddings: normalize_embeddings ?? true,
+    });
+    self.postMessage({
+      status: "complete",
+      message: "complete",
+      output: output,
+    });
+  } catch (e) {
+    self.postMessage({ error: e });
+  }
+});
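
Usage sketch (not part of the commit): a minimal main-thread driver for this
worker. The asset URLs and the "t5-small" model ID below are illustrative
placeholders, and the worker is created with { type: "module" } so that the
dynamic import() of the wasm bindings is available.

  const worker = new Worker("T5ModelEncoderWorker.js", { type: "module" });

  // surface progress, results, and errors posted by the worker
  worker.addEventListener("message", (event) => {
    const { status, message, output, error } = event.data;
    if (error) {
      console.error(error);
    } else if (status === "complete") {
      console.log(output); // sentence embeddings
    } else {
      console.log(status, message); // loading / encoding progress
    }
  });

  // kick off a run; the worker caches the downloaded assets
  worker.postMessage({
    weightsURL: "https://example.com/t5-small/model.safetensors",
    tokenizerURL: "https://example.com/t5-small/tokenizer.json",
    configURL: "https://example.com/t5-small/config.json",
    modelID: "t5-small",
    sentences: ["Hello world"],
    normalize_embeddings: true,
  });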