diff options
-rw-r--r-- | candle-wasm-examples/segment-anything/README.md | 26 | ||||
-rw-r--r-- | candle-wasm-examples/segment-anything/lib-example.html | 407 | ||||
-rw-r--r-- | candle-wasm-examples/segment-anything/samWorker.js | 156 | ||||
-rw-r--r-- | candle-wasm-examples/segment-anything/src/bin/m.rs | 22 |
4 files changed, 609 insertions, 2 deletions
diff --git a/candle-wasm-examples/segment-anything/README.md b/candle-wasm-examples/segment-anything/README.md new file mode 100644 index 00000000..04ff2033 --- /dev/null +++ b/candle-wasm-examples/segment-anything/README.md @@ -0,0 +1,26 @@ +## Running Segment Anything Example + +Here, we provide two examples of how to run Whisper using a Candle-compiled WASM binary and runtimes. + +### Vanilla JS and WebWorkers + +To build and test the UI made in Vanilla JS and WebWorkers, first we need to build the WASM library: + +```bash +sh build-lib.sh +``` + +This will bundle the library under `./build` and we can import it inside our WebWorker like a normal JS module: + +```js +import init, { Model } from "./build/m.js"; +``` + +The full example can be found under `./lib-example.html`. All needed assets are fetched from the web, so no need to download anything. +Finally, you can preview the example by running a local HTTP server. For example: + +```bash +python -m http.server +``` + +Then open `http://localhost:8000/lib-example.html` in your browser. diff --git a/candle-wasm-examples/segment-anything/lib-example.html b/candle-wasm-examples/segment-anything/lib-example.html new file mode 100644 index 00000000..127b9152 --- /dev/null +++ b/candle-wasm-examples/segment-anything/lib-example.html @@ -0,0 +1,407 @@ +<html> + <head> + <meta content="text/html;charset=utf-8" http-equiv="Content-Type" /> + <title>Candle Segment Anything Model (SAM) Rust/WASM</title> + </head> + <body></body> +</html> + +<!DOCTYPE html> +<html> + <head> + <meta charset="UTF-8" /> + <meta name="viewport" content="width=device-width, initial-scale=1.0" /> + <style> + @import url("https://fonts.googleapis.com/css2?family=Source+Code+Pro:wght@200;300;400&family=Source+Sans+3:wght@100;200;300;400;500;600;700;800;900&display=swap"); + html, + body { + font-family: "Source Sans 3", sans-serif; + } + </style> + <script src="https://cdn.tailwindcss.com"></script> + <script type="module"> + // base url for audio examples + const MODEL_BASEURL = + "https://huggingface.co/lmz/candle-sam/resolve/main/"; + + // models base url + const MODELS = { + sam_mobile_tiny: { + url: "mobile_sam-tiny-vitt.safetensors", + }, + sam_base: { + url: "sam_vit_b_01ec64.safetensors", + }, + }; + const samWorker = new Worker("./samWorker.js", { type: "module" }); + + async function segmentPoints( + modelURL, // URL to the weights file + modelID, // model ID + imageURL, // URL to the audio file + points // {x, y} points to prompt image + ) { + return new Promise((resolve, reject) => { + function messageHandler(event) { + console.log(event.data); + if ("status" in event.data) { + updateStatus(event.data); + } + if ("error" in event.data) { + samWorker.removeEventListener("message", messageHandler); + reject(new Error(event.data.error)); + } + if (event.data.status === "complete-embedding") { + samWorker.removeEventListener("message", messageHandler); + resolve(); + } + if (event.data.status === "complete") { + samWorker.removeEventListener("message", messageHandler); + resolve(event.data.output); + } + } + samWorker.addEventListener("message", messageHandler); + samWorker.postMessage({ + modelURL, + modelID, + imageURL, + points, + }); + }); + } + function updateStatus(statusMessage) { + statusOutput.innerText = event.data.message; + } + + const clearBtn = document.querySelector("#clear-btn"); + const canvas = document.querySelector("#canvas"); + const mask = document.querySelector("#mask"); + const ctxCanvas = canvas.getContext("2d"); + const ctxMask = mask.getContext("2d"); + const fileUpload = document.querySelector("#file-upload"); + const dropArea = document.querySelector("#drop-area"); + const dropButtons = document.querySelector("#drop-buttons"); + const imagesExamples = document.querySelector("#image-select"); + const modelSelection = document.querySelector("#model"); + const statusOutput = document.querySelector("#output-status"); + + //add event listener to file input + fileUpload.addEventListener("change", (e) => { + const target = e.target; + if (target.files.length > 0) { + const href = URL.createObjectURL(target.files[0]); + cleanImageCanvas(); + drawImageCanvas(href); + setImageEmbeddings(href); + } + }); + // add event listener to drop-area + dropArea.addEventListener("dragenter", (e) => { + e.preventDefault(); + dropArea.classList.add("border-blue-700"); + }); + dropArea.addEventListener("dragleave", (e) => { + e.preventDefault(); + dropArea.classList.remove("border-blue-700"); + }); + dropArea.addEventListener("dragover", (e) => { + e.preventDefault(); + }); + dropArea.addEventListener("drop", (e) => { + e.preventDefault(); + dropArea.classList.remove("border-blue-700"); + const url = e.dataTransfer.getData("text/uri-list"); + const files = e.dataTransfer.files; + + if (files.length > 0) { + const href = URL.createObjectURL(files[0]); + cleanImageCanvas(); + drawImageCanvas(href); + setImageEmbeddings(href); + } else if (url) { + cleanImageCanvas(); + drawImageCanvas(url); + setImageEmbeddings(url); + } + }); + + let hasImage = false; + let isSegmenting = false; + let isEmbedding = false; + let currentImageURL = ""; + //add event listener to image examples + imagesExamples.addEventListener("click", (e) => { + if (isEmbedding || isSegmenting) { + return; + } + const target = e.target; + if (target.nodeName === "IMG") { + const href = target.src; + cleanImageCanvas(); + drawImageCanvas(href); + setImageEmbeddings(href); + } + }); + //add event listener to clear button + clearBtn.addEventListener("click", () => { + cleanImageCanvas(); + }); + //add click event to canvas + canvas.addEventListener("click", async (event) => { + if (!hasImage || isEmbedding || isSegmenting) { + return; + } + const targetBox = event.target.getBoundingClientRect(); + const x = (event.clientX - targetBox.left) / targetBox.width; + const y = (event.clientY - targetBox.top) / targetBox.height; + isSegmenting = true; + const { maskURL } = await getSegmentationMask({ x, y }); + isSegmenting = false; + drawMask(maskURL); + }); + + async function getSegmentationMask(points) { + const modelID = modelSelection.value; + const modelURL = MODEL_BASEURL + MODELS[modelID].url; + const imageURL = currentImageURL; + const { maskURL } = await segmentPoints( + modelURL, + modelID, + imageURL, + points + ); + return { maskURL }; + } + async function setImageEmbeddings(imageURL) { + if (isEmbedding) { + return; + } + canvas.classList.remove("cursor-pointer"); + canvas.classList.add("cursor-wait"); + clearBtn.disabled = true; + const modelID = modelSelection.value; + const modelURL = MODEL_BASEURL + MODELS[modelID].url; + isEmbedding = true; + await segmentPoints(modelURL, modelID, imageURL); + canvas.classList.remove("cursor-wait"); + canvas.classList.add("cursor-pointer"); + clearBtn.disabled = false; + isEmbedding = false; + currentImageURL = imageURL; + } + + function cleanImageCanvas() { + ctxCanvas.clearRect(0, 0, canvas.width, canvas.height); + ctxMask.clearRect(0, 0, canvas.width, canvas.height); + hasImage = false; + isEmbedding = false; + isSegmenting = false; + currentImageURL = ""; + clearBtn.classList.add("invisible"); + canvas.parentElement.style.height = "auto"; + dropButtons.classList.remove("invisible"); + } + function drawMask(maskURL) { + if (!maskURL) { + throw new Error("No mask URL provided"); + } + + const img = new Image(); + img.crossOrigin = "anonymous"; + + img.onload = () => { + mask.width = canvas.width; + mask.height = canvas.height; + ctxMask.drawImage(canvas, 0, 0); + ctxMask.globalCompositeOperation = "source-atop"; + ctxMask.fillStyle = "rgba(255, 0, 0, 0.6)"; + ctxMask.fillRect(0, 0, canvas.width, canvas.height); + ctxMask.globalCompositeOperation = "destination-in"; + ctxMask.drawImage(img, 0, 0); + }; + img.src = maskURL; + } + function drawImageCanvas(imgURL) { + if (!imgURL) { + throw new Error("No image URL provided"); + } + + ctxCanvas.clearRect(0, 0, canvas.width, canvas.height); + ctxCanvas.clearRect(0, 0, canvas.width, canvas.height); + + const img = new Image(); + img.crossOrigin = "anonymous"; + + img.onload = () => { + canvas.width = img.width; + canvas.height = img.height; + ctxCanvas.drawImage(img, 0, 0); + canvas.parentElement.style.height = canvas.offsetHeight + "px"; + hasImage = true; + clearBtn.classList.remove("invisible"); + dropButtons.classList.add("invisible"); + }; + img.src = imgURL; + } + + const observer = new ResizeObserver((entries) => { + for (let entry of entries) { + if (entry.target === canvas) { + canvas.parentElement.style.height = canvas.offsetHeight + "px"; + } + } + }); + observer.observe(canvas); + </script> + </head> + <body class="container max-w-4xl mx-auto p-4"> + <main class="grid grid-cols-1 gap-8 relative"> + <span class="absolute text-5xl -ml-[1em]">🕯️</span> + <div> + <h1 class="text-5xl font-bold">Candle Segment Anything</h1> + <h2 class="text-2xl font-bold">Rust/WASM Demo</h2> + <p class="max-w-lg"> + Zero-shot image segmentation with + <a + href="https://segment-anything.com" + class="underline hover:text-blue-500 hover:no-underline" + target="_blank" + >Segment Anything Model (SAM)</a + > + and + <a + href="https://github.com/ChaoningZhang/MobileSAM" + class="underline hover:text-blue-500 hover:no-underline" + target="_blank" + >MobileSAM </a + >. It runs in the browser with a WASM runtime built with + <a + href="https://github.com/huggingface/candle/" + target="_blank" + class="underline hover:text-blue-500 hover:no-underline" + >Candle + </a> + </p> + </div> + <div> + <label for="model" class="font-medium">Models Options: </label> + <select + id="model" + class="border-2 border-gray-500 rounded-md font-light" + > + <option value="sam_mobile_tiny" selected> + Mobile SAM Tiny (40.6 MB) + </option> + <option value="sam_base">SAM Base (375 MB)</option> + </select> + </div> + <div> + <p class="text-xs italic max-w-lg"> + <b>Note:</b> + The model's first run may take a few seconds as it loads and caches + the model in the browser, and then creates the image embeddings. Any + subsequent clicks on points will be significantly faster. + </p> + </div> + <div class="relative max-w-lg"> + <div class="flex justify-between items-center"> + <div class="px-2 rounded-md inline text-xs"> + <span id="output-status" class="m-auto font-light"></span> + </div> + <button + id="clear-btn" + class="text-xs bg-white rounded-md disabled:opacity-50 flex gap-1 items-center invisible" + > + <svg + class="" + xmlns="http://www.w3.org/2000/svg" + viewBox="0 0 13 12" + height="1em" + > + <path + d="M1.6.7 12 11.1M12 .7 1.6 11.1" + stroke="#2E3036" + stroke-width="2" + /> + </svg> + Clear image + </button> + </div> + <div + id="drop-area" + class="flex flex-col items-center justify-center border-2 border-gray-300 border-dashed rounded-xl relative p-20 w-full overflow-hidden" + > + <div + id="drop-buttons" + class="flex flex-col items-center justify-center space-y-1 text-center relative z-10" + > + <svg + width="25" + height="25" + viewBox="0 0 25 25" + fill="none" + xmlns="http://www.w3.org/2000/svg" + > + <path + d="M3.5 24.3a3 3 0 0 1-1.9-.8c-.5-.5-.8-1.2-.8-1.9V2.9c0-.7.3-1.3.8-1.9.6-.5 1.2-.7 2-.7h18.6c.7 0 1.3.2 1.9.7.5.6.7 1.2.7 2v18.6c0 .7-.2 1.4-.7 1.9a3 3 0 0 1-2 .8H3.6Zm0-2.7h18.7V2.9H3.5v18.7Zm2.7-2.7h13.3c.3 0 .5 0 .6-.3v-.7l-3.7-5a.6.6 0 0 0-.6-.2c-.2 0-.4 0-.5.3l-3.5 4.6-2.4-3.3a.6.6 0 0 0-.6-.3c-.2 0-.4.1-.5.3l-2.7 3.6c-.1.2-.2.4 0 .7.1.2.3.3.6.3Z" + fill="#000" + /> + </svg> + <div class="flex text-sm text-gray-600"> + <label + for="file-upload" + class="relative cursor-pointer bg-white rounded-md font-medium text-blue-950 hover:text-blue-700" + > + <span>Drag and drop your image here</span> + <span class="block text-xs">or</span> + <span class="block text-xs">Click to upload</span> + </label> + </div> + <input + id="file-upload" + name="file-upload" + type="file" + class="sr-only" + /> + </div> + <canvas id="canvas" class="absolute w-full"></canvas> + <canvas + id="mask" + class="pointer-events-none absolute w-full" + ></canvas> + </div> + <div class="text-right py-2"> + <button + id="share-btn" + class="bg-white rounded-md hover:outline outline-orange-200 disabled:opacity-50 invisible" + > + <img + src="https://huggingface.co/datasets/huggingface/badges/raw/main/share-to-community-sm.svg" + /> + </button> + </div> + </div> + <div> + <div + class="flex gap-3 items-center overflow-x-scroll" + id="image-select" + > + <h3 class="font-medium">Examples:</h3> + + <img + src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/candle/examples/sf.jpg" + class="cursor-pointer w-24 h-24 object-cover" + /> + <img + src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/candle/examples/bike.jpeg" + class="cursor-pointer w-24 h-24 object-cover" + /> + <img + src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/candle/examples/000000000077.jpg" + class="cursor-pointer w-24 h-24 object-cover" + /> + </div> + </div> + </main> + </body> +</html> diff --git a/candle-wasm-examples/segment-anything/samWorker.js b/candle-wasm-examples/segment-anything/samWorker.js new file mode 100644 index 00000000..b90498de --- /dev/null +++ b/candle-wasm-examples/segment-anything/samWorker.js @@ -0,0 +1,156 @@ +//load the candle SAM Model wasm module +import init, { Model } from "./build/m.js"; + +async function fetchArrayBuffer(url, cacheModel = true) { + if (!cacheModel) + return new Uint8Array(await (await fetch(url)).arrayBuffer()); + const cacheName = "sam-candle-cache"; + const cache = await caches.open(cacheName); + const cachedResponse = await cache.match(url); + if (cachedResponse) { + const data = await cachedResponse.arrayBuffer(); + return new Uint8Array(data); + } + const res = await fetch(url, { cache: "force-cache" }); + cache.put(url, res.clone()); + return new Uint8Array(await res.arrayBuffer()); +} +class SAMModel { + static instance = {}; + // keep current image embeddings state + static imageArrayHash = {}; + // Add a new property to hold the current modelID + static currentModelID = null; + + static async getInstance(modelURL, modelID) { + if (!this.instance[modelID]) { + await init(); + + self.postMessage({ + status: "loading", + message: `Loading Model ${modelID}`, + }); + const weightsArrayU8 = await fetchArrayBuffer(modelURL); + this.instance[modelID] = new Model( + weightsArrayU8, + /tiny|mobile/.test(modelID) + ); + } else { + self.postMessage({ status: "loading", message: "Model Already Loaded" }); + } + // Set the current modelID to the modelID that was passed in + this.currentModelID = modelID; + return this.instance[modelID]; + } + + // Remove the modelID parameter from setImageEmbeddings + static setImageEmbeddings(imageArrayU8) { + // check if image embeddings are already set for this image and model + const imageArrayHash = this.getSimpleHash(imageArrayU8); + if ( + this.imageArrayHash[this.currentModelID] === imageArrayHash && + this.instance[this.currentModelID] + ) { + self.postMessage({ + status: "embedding", + message: "Embeddings Already Set", + }); + return; + } + this.imageArrayHash[this.currentModelID] = imageArrayHash; + this.instance[this.currentModelID].set_image_embeddings(imageArrayU8); + self.postMessage({ status: "embedding", message: "Embeddings Set" }); + } + + static getSimpleHash(imageArrayU8) { + // get simple hash of imageArrayU8 + let imageArrayHash = 0; + for (let i = 0; i < imageArrayU8.length; i += 100) { + imageArrayHash ^= imageArrayU8[i]; + } + return imageArrayHash.toString(16); + } +} + +async function createImageCanvas( + { mask_shape, mask_data }, // mask + { original_width, original_height, width, height } // original image +) { + const [_, __, shape_width, shape_height] = mask_shape; + const maskCanvas = new OffscreenCanvas(shape_width, shape_height); // canvas for mask + const maskCtx = maskCanvas.getContext("2d"); + const canvas = new OffscreenCanvas(original_width, original_height); // canvas for creating mask with original image size + const ctx = canvas.getContext("2d"); + + const imageData = maskCtx.createImageData( + maskCanvas.width, + maskCanvas.height + ); + const data = imageData.data; + + for (let p = 0; p < data.length; p += 4) { + data[p] = 0; + data[p + 1] = 0; + data[p + 2] = 0; + data[p + 3] = mask_data[p / 4] * 255; + } + maskCtx.putImageData(imageData, 0, 0); + + let sx, sy; + if (original_height < original_width) { + sy = original_height / original_width; + sx = 1; + } else { + sy = 1; + sx = original_width / original_height; + } + ctx.drawImage( + maskCanvas, + 0, + 0, + maskCanvas.width * sx, + maskCanvas.height * sy, + 0, + 0, + original_width, + original_height + ); + + const blob = await canvas.convertToBlob(); + return URL.createObjectURL(blob); +} + +self.addEventListener("message", async (event) => { + const { modelURL, modelID, imageURL, points } = event.data; + try { + self.postMessage({ status: "loading", message: "Starting SAM" }); + const sam = await SAMModel.getInstance(modelURL, modelID); + + self.postMessage({ status: "loading", message: "Loading Image" }); + const imageArrayU8 = await fetchArrayBuffer(imageURL, false); + + self.postMessage({ status: "embedding", message: "Creating Embeddings" }); + SAMModel.setImageEmbeddings(imageArrayU8); + if (!points) { + // no points only do the embeddings + self.postMessage({ + status: "complete-embedding", + message: "Embeddings Complete", + }); + return; + } + + self.postMessage({ status: "segmenting", message: "Segmenting" }); + const result = sam.mask_for_point(points.x, points.y); + const { mask, image } = JSON.parse(result); + const maskDataURL = await createImageCanvas(mask, image); + // Send the segment back to the main thread as JSON + self.postMessage({ + status: "complete", + message: "Segmentation Complete", + output: { maskURL: maskDataURL }, + }); + } catch (e) { + self.postMessage({ error: e }); + } +}); diff --git a/candle-wasm-examples/segment-anything/src/bin/m.rs b/candle-wasm-examples/segment-anything/src/bin/m.rs index b53f5b9b..949c18a0 100644 --- a/candle-wasm-examples/segment-anything/src/bin/m.rs +++ b/candle-wasm-examples/segment-anything/src/bin/m.rs @@ -98,7 +98,7 @@ impl Model { Some((x, y)), false, )?; - let iou = iou_predictions.to_vec1::<f32>()?[0]; + let iou = iou_predictions.flatten(0, 1)?.to_vec1::<f32>()?[0]; let mask_shape = mask.dims().to_vec(); let mask_data = mask.ge(0f32)?.flatten_all()?.to_vec1::<u8>()?; let mask = Mask { @@ -106,7 +106,13 @@ impl Model { mask_shape, mask_data, }; - let json = serde_json::to_string(&mask)?; + let image = Image { + original_width: embeddings.original_width, + original_height: embeddings.original_height, + width: embeddings.width, + height: embeddings.height, + }; + let json = serde_json::to_string(&MaskImage { mask, image })?; Ok(json) } } @@ -117,6 +123,18 @@ struct Mask { mask_shape: Vec<usize>, mask_data: Vec<u8>, } +#[derive(serde::Serialize, serde::Deserialize)] +struct Image { + original_width: u32, + original_height: u32, + width: u32, + height: u32, +} +#[derive(serde::Serialize, serde::Deserialize)] +struct MaskImage { + mask: Mask, + image: Image, +} fn main() { console_error_panic_hook::set_once(); |