summaryrefslogtreecommitdiff
path: root/candle-wasm-examples/segment-anything/samWorker.js
diff options
context:
space:
mode:
Diffstat (limited to 'candle-wasm-examples/segment-anything/samWorker.js')
-rw-r--r--candle-wasm-examples/segment-anything/samWorker.js155
1 files changed, 155 insertions, 0 deletions
diff --git a/candle-wasm-examples/segment-anything/samWorker.js b/candle-wasm-examples/segment-anything/samWorker.js
new file mode 100644
index 00000000..c1a152ef
--- /dev/null
+++ b/candle-wasm-examples/segment-anything/samWorker.js
@@ -0,0 +1,155 @@
+//load the candle SAM Model wasm module
+import init, { Model } from "./build/m.js";
+
+async function fetchArrayBuffer(url, cacheModel = true) {
+ if (!cacheModel)
+ return new Uint8Array(await (await fetch(url)).arrayBuffer());
+ const cacheName = "sam-candle-cache";
+ const cache = await caches.open(cacheName);
+ const cachedResponse = await cache.match(url);
+ if (cachedResponse) {
+ const data = await cachedResponse.arrayBuffer();
+ return new Uint8Array(data);
+ }
+ const res = await fetch(url, { cache: "force-cache" });
+ cache.put(url, res.clone());
+ return new Uint8Array(await res.arrayBuffer());
+}
+class SAMModel {
+ static instance = {};
+ // keep current image embeddings state
+ static imageArrayHash = {};
+ // Add a new property to hold the current modelID
+ static currentModelID = null;
+
+ static async getInstance(modelURL, modelID) {
+ if (!this.instance[modelID]) {
+ await init();
+
+ self.postMessage({
+ status: "loading",
+ message: `Loading Model ${modelID}`,
+ });
+ const weightsArrayU8 = await fetchArrayBuffer(modelURL);
+ this.instance[modelID] = new Model(
+ weightsArrayU8,
+ /tiny|mobile/.test(modelID)
+ );
+ } else {
+ self.postMessage({ status: "loading", message: "Model Already Loaded" });
+ }
+ // Set the current modelID to the modelID that was passed in
+ this.currentModelID = modelID;
+ return this.instance[modelID];
+ }
+
+ // Remove the modelID parameter from setImageEmbeddings
+ static setImageEmbeddings(imageArrayU8) {
+ // check if image embeddings are already set for this image and model
+ const imageArrayHash = this.getSimpleHash(imageArrayU8);
+ if (
+ this.imageArrayHash[this.currentModelID] === imageArrayHash &&
+ this.instance[this.currentModelID]
+ ) {
+ self.postMessage({
+ status: "embedding",
+ message: "Embeddings Already Set",
+ });
+ return;
+ }
+ this.imageArrayHash[this.currentModelID] = imageArrayHash;
+ this.instance[this.currentModelID].set_image_embeddings(imageArrayU8);
+ self.postMessage({ status: "embedding", message: "Embeddings Set" });
+ }
+
+ static getSimpleHash(imageArrayU8) {
+ // get simple hash of imageArrayU8
+ let imageArrayHash = 0;
+ for (let i = 0; i < imageArrayU8.length; i += 100) {
+ imageArrayHash ^= imageArrayU8[i];
+ }
+ return imageArrayHash.toString(16);
+ }
+}
+
+async function createImageCanvas(
+ { mask_shape, mask_data }, // mask
+ { original_width, original_height, width, height } // original image
+) {
+ const [_, __, shape_width, shape_height] = mask_shape;
+ const maskCanvas = new OffscreenCanvas(shape_width, shape_height); // canvas for mask
+ const maskCtx = maskCanvas.getContext("2d");
+ const canvas = new OffscreenCanvas(original_width, original_height); // canvas for creating mask with original image size
+ const ctx = canvas.getContext("2d");
+
+ const imageData = maskCtx.createImageData(
+ maskCanvas.width,
+ maskCanvas.height
+ );
+ const data = imageData.data;
+
+ for (let p = 0; p < data.length; p += 4) {
+ data[p] = 0;
+ data[p + 1] = 0;
+ data[p + 2] = 0;
+ data[p + 3] = mask_data[p / 4] * 255;
+ }
+ maskCtx.putImageData(imageData, 0, 0);
+
+ let sx, sy;
+ if (original_height < original_width) {
+ sy = original_height / original_width;
+ sx = 1;
+ } else {
+ sy = 1;
+ sx = original_width / original_height;
+ }
+ ctx.drawImage(
+ maskCanvas,
+ 0,
+ 0,
+ maskCanvas.width * sx,
+ maskCanvas.height * sy,
+ 0,
+ 0,
+ original_width,
+ original_height
+ );
+
+ const blob = await canvas.convertToBlob();
+ return URL.createObjectURL(blob);
+}
+
+self.addEventListener("message", async (event) => {
+ const { modelURL, modelID, imageURL, points } = event.data;
+ try {
+ self.postMessage({ status: "loading", message: "Starting SAM" });
+ const sam = await SAMModel.getInstance(modelURL, modelID);
+
+ self.postMessage({ status: "loading", message: "Loading Image" });
+ const imageArrayU8 = await fetchArrayBuffer(imageURL, false);
+
+ self.postMessage({ status: "embedding", message: "Creating Embeddings" });
+ SAMModel.setImageEmbeddings(imageArrayU8);
+ if (!points) {
+ // no points only do the embeddings
+ self.postMessage({
+ status: "complete-embedding",
+ message: "Embeddings Complete",
+ });
+ return;
+ }
+
+ self.postMessage({ status: "segmenting", message: "Segmenting" });
+ const { mask, image } = sam.mask_for_point(points.x, points.y);
+ const maskDataURL = await createImageCanvas(mask, image);
+ // Send the segment back to the main thread as JSON
+ self.postMessage({
+ status: "complete",
+ message: "Segmentation Complete",
+ output: { maskURL: maskDataURL },
+ });
+ } catch (e) {
+ self.postMessage({ error: e });
+ }
+});