diff options
Diffstat (limited to 'candle-wasm-examples/segment-anything/samWorker.js')
-rw-r--r-- | candle-wasm-examples/segment-anything/samWorker.js | 155 |
1 files changed, 155 insertions, 0 deletions
diff --git a/candle-wasm-examples/segment-anything/samWorker.js b/candle-wasm-examples/segment-anything/samWorker.js new file mode 100644 index 00000000..c1a152ef --- /dev/null +++ b/candle-wasm-examples/segment-anything/samWorker.js @@ -0,0 +1,155 @@ +//load the candle SAM Model wasm module +import init, { Model } from "./build/m.js"; + +async function fetchArrayBuffer(url, cacheModel = true) { + if (!cacheModel) + return new Uint8Array(await (await fetch(url)).arrayBuffer()); + const cacheName = "sam-candle-cache"; + const cache = await caches.open(cacheName); + const cachedResponse = await cache.match(url); + if (cachedResponse) { + const data = await cachedResponse.arrayBuffer(); + return new Uint8Array(data); + } + const res = await fetch(url, { cache: "force-cache" }); + cache.put(url, res.clone()); + return new Uint8Array(await res.arrayBuffer()); +} +class SAMModel { + static instance = {}; + // keep current image embeddings state + static imageArrayHash = {}; + // Add a new property to hold the current modelID + static currentModelID = null; + + static async getInstance(modelURL, modelID) { + if (!this.instance[modelID]) { + await init(); + + self.postMessage({ + status: "loading", + message: `Loading Model ${modelID}`, + }); + const weightsArrayU8 = await fetchArrayBuffer(modelURL); + this.instance[modelID] = new Model( + weightsArrayU8, + /tiny|mobile/.test(modelID) + ); + } else { + self.postMessage({ status: "loading", message: "Model Already Loaded" }); + } + // Set the current modelID to the modelID that was passed in + this.currentModelID = modelID; + return this.instance[modelID]; + } + + // Remove the modelID parameter from setImageEmbeddings + static setImageEmbeddings(imageArrayU8) { + // check if image embeddings are already set for this image and model + const imageArrayHash = this.getSimpleHash(imageArrayU8); + if ( + this.imageArrayHash[this.currentModelID] === imageArrayHash && + this.instance[this.currentModelID] + ) { + self.postMessage({ + status: "embedding", + message: "Embeddings Already Set", + }); + return; + } + this.imageArrayHash[this.currentModelID] = imageArrayHash; + this.instance[this.currentModelID].set_image_embeddings(imageArrayU8); + self.postMessage({ status: "embedding", message: "Embeddings Set" }); + } + + static getSimpleHash(imageArrayU8) { + // get simple hash of imageArrayU8 + let imageArrayHash = 0; + for (let i = 0; i < imageArrayU8.length; i += 100) { + imageArrayHash ^= imageArrayU8[i]; + } + return imageArrayHash.toString(16); + } +} + +async function createImageCanvas( + { mask_shape, mask_data }, // mask + { original_width, original_height, width, height } // original image +) { + const [_, __, shape_width, shape_height] = mask_shape; + const maskCanvas = new OffscreenCanvas(shape_width, shape_height); // canvas for mask + const maskCtx = maskCanvas.getContext("2d"); + const canvas = new OffscreenCanvas(original_width, original_height); // canvas for creating mask with original image size + const ctx = canvas.getContext("2d"); + + const imageData = maskCtx.createImageData( + maskCanvas.width, + maskCanvas.height + ); + const data = imageData.data; + + for (let p = 0; p < data.length; p += 4) { + data[p] = 0; + data[p + 1] = 0; + data[p + 2] = 0; + data[p + 3] = mask_data[p / 4] * 255; + } + maskCtx.putImageData(imageData, 0, 0); + + let sx, sy; + if (original_height < original_width) { + sy = original_height / original_width; + sx = 1; + } else { + sy = 1; + sx = original_width / original_height; + } + ctx.drawImage( + maskCanvas, + 0, + 0, + maskCanvas.width * sx, + maskCanvas.height * sy, + 0, + 0, + original_width, + original_height + ); + + const blob = await canvas.convertToBlob(); + return URL.createObjectURL(blob); +} + +self.addEventListener("message", async (event) => { + const { modelURL, modelID, imageURL, points } = event.data; + try { + self.postMessage({ status: "loading", message: "Starting SAM" }); + const sam = await SAMModel.getInstance(modelURL, modelID); + + self.postMessage({ status: "loading", message: "Loading Image" }); + const imageArrayU8 = await fetchArrayBuffer(imageURL, false); + + self.postMessage({ status: "embedding", message: "Creating Embeddings" }); + SAMModel.setImageEmbeddings(imageArrayU8); + if (!points) { + // no points only do the embeddings + self.postMessage({ + status: "complete-embedding", + message: "Embeddings Complete", + }); + return; + } + + self.postMessage({ status: "segmenting", message: "Segmenting" }); + const { mask, image } = sam.mask_for_point(points.x, points.y); + const maskDataURL = await createImageCanvas(mask, image); + // Send the segment back to the main thread as JSON + self.postMessage({ + status: "complete", + message: "Segmentation Complete", + output: { maskURL: maskDataURL }, + }); + } catch (e) { + self.postMessage({ error: e }); + } +}); |