summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-wasm-examples/segment-anything/README.md26
-rw-r--r--candle-wasm-examples/segment-anything/lib-example.html407
-rw-r--r--candle-wasm-examples/segment-anything/samWorker.js156
-rw-r--r--candle-wasm-examples/segment-anything/src/bin/m.rs22
4 files changed, 609 insertions, 2 deletions
diff --git a/candle-wasm-examples/segment-anything/README.md b/candle-wasm-examples/segment-anything/README.md
new file mode 100644
index 00000000..04ff2033
--- /dev/null
+++ b/candle-wasm-examples/segment-anything/README.md
@@ -0,0 +1,26 @@
+## Running Segment Anything Example
+
+Here, we provide two examples of how to run Whisper using a Candle-compiled WASM binary and runtimes.
+
+### Vanilla JS and WebWorkers
+
+To build and test the UI made in Vanilla JS and WebWorkers, first we need to build the WASM library:
+
+```bash
+sh build-lib.sh
+```
+
+This will bundle the library under `./build` and we can import it inside our WebWorker like a normal JS module:
+
+```js
+import init, { Model } from "./build/m.js";
+```
+
+The full example can be found under `./lib-example.html`. All needed assets are fetched from the web, so no need to download anything.
+Finally, you can preview the example by running a local HTTP server. For example:
+
+```bash
+python -m http.server
+```
+
+Then open `http://localhost:8000/lib-example.html` in your browser.
diff --git a/candle-wasm-examples/segment-anything/lib-example.html b/candle-wasm-examples/segment-anything/lib-example.html
new file mode 100644
index 00000000..127b9152
--- /dev/null
+++ b/candle-wasm-examples/segment-anything/lib-example.html
@@ -0,0 +1,407 @@
+<html>
+ <head>
+ <meta content="text/html;charset=utf-8" http-equiv="Content-Type" />
+ <title>Candle Segment Anything Model (SAM) Rust/WASM</title>
+ </head>
+ <body></body>
+</html>
+
+<!DOCTYPE html>
+<html>
+ <head>
+ <meta charset="UTF-8" />
+ <meta name="viewport" content="width=device-width, initial-scale=1.0" />
+ <style>
+ @import url("https://fonts.googleapis.com/css2?family=Source+Code+Pro:wght@200;300;400&family=Source+Sans+3:wght@100;200;300;400;500;600;700;800;900&display=swap");
+ html,
+ body {
+ font-family: "Source Sans 3", sans-serif;
+ }
+ </style>
+ <script src="https://cdn.tailwindcss.com"></script>
+ <script type="module">
+ // base url for audio examples
+ const MODEL_BASEURL =
+ "https://huggingface.co/lmz/candle-sam/resolve/main/";
+
+ // models base url
+ const MODELS = {
+ sam_mobile_tiny: {
+ url: "mobile_sam-tiny-vitt.safetensors",
+ },
+ sam_base: {
+ url: "sam_vit_b_01ec64.safetensors",
+ },
+ };
+ const samWorker = new Worker("./samWorker.js", { type: "module" });
+
+ async function segmentPoints(
+ modelURL, // URL to the weights file
+ modelID, // model ID
+ imageURL, // URL to the audio file
+ points // {x, y} points to prompt image
+ ) {
+ return new Promise((resolve, reject) => {
+ function messageHandler(event) {
+ console.log(event.data);
+ if ("status" in event.data) {
+ updateStatus(event.data);
+ }
+ if ("error" in event.data) {
+ samWorker.removeEventListener("message", messageHandler);
+ reject(new Error(event.data.error));
+ }
+ if (event.data.status === "complete-embedding") {
+ samWorker.removeEventListener("message", messageHandler);
+ resolve();
+ }
+ if (event.data.status === "complete") {
+ samWorker.removeEventListener("message", messageHandler);
+ resolve(event.data.output);
+ }
+ }
+ samWorker.addEventListener("message", messageHandler);
+ samWorker.postMessage({
+ modelURL,
+ modelID,
+ imageURL,
+ points,
+ });
+ });
+ }
+ function updateStatus(statusMessage) {
+ statusOutput.innerText = event.data.message;
+ }
+
+ const clearBtn = document.querySelector("#clear-btn");
+ const canvas = document.querySelector("#canvas");
+ const mask = document.querySelector("#mask");
+ const ctxCanvas = canvas.getContext("2d");
+ const ctxMask = mask.getContext("2d");
+ const fileUpload = document.querySelector("#file-upload");
+ const dropArea = document.querySelector("#drop-area");
+ const dropButtons = document.querySelector("#drop-buttons");
+ const imagesExamples = document.querySelector("#image-select");
+ const modelSelection = document.querySelector("#model");
+ const statusOutput = document.querySelector("#output-status");
+
+ //add event listener to file input
+ fileUpload.addEventListener("change", (e) => {
+ const target = e.target;
+ if (target.files.length > 0) {
+ const href = URL.createObjectURL(target.files[0]);
+ cleanImageCanvas();
+ drawImageCanvas(href);
+ setImageEmbeddings(href);
+ }
+ });
+ // add event listener to drop-area
+ dropArea.addEventListener("dragenter", (e) => {
+ e.preventDefault();
+ dropArea.classList.add("border-blue-700");
+ });
+ dropArea.addEventListener("dragleave", (e) => {
+ e.preventDefault();
+ dropArea.classList.remove("border-blue-700");
+ });
+ dropArea.addEventListener("dragover", (e) => {
+ e.preventDefault();
+ });
+ dropArea.addEventListener("drop", (e) => {
+ e.preventDefault();
+ dropArea.classList.remove("border-blue-700");
+ const url = e.dataTransfer.getData("text/uri-list");
+ const files = e.dataTransfer.files;
+
+ if (files.length > 0) {
+ const href = URL.createObjectURL(files[0]);
+ cleanImageCanvas();
+ drawImageCanvas(href);
+ setImageEmbeddings(href);
+ } else if (url) {
+ cleanImageCanvas();
+ drawImageCanvas(url);
+ setImageEmbeddings(url);
+ }
+ });
+
+ let hasImage = false;
+ let isSegmenting = false;
+ let isEmbedding = false;
+ let currentImageURL = "";
+ //add event listener to image examples
+ imagesExamples.addEventListener("click", (e) => {
+ if (isEmbedding || isSegmenting) {
+ return;
+ }
+ const target = e.target;
+ if (target.nodeName === "IMG") {
+ const href = target.src;
+ cleanImageCanvas();
+ drawImageCanvas(href);
+ setImageEmbeddings(href);
+ }
+ });
+ //add event listener to clear button
+ clearBtn.addEventListener("click", () => {
+ cleanImageCanvas();
+ });
+ //add click event to canvas
+ canvas.addEventListener("click", async (event) => {
+ if (!hasImage || isEmbedding || isSegmenting) {
+ return;
+ }
+ const targetBox = event.target.getBoundingClientRect();
+ const x = (event.clientX - targetBox.left) / targetBox.width;
+ const y = (event.clientY - targetBox.top) / targetBox.height;
+ isSegmenting = true;
+ const { maskURL } = await getSegmentationMask({ x, y });
+ isSegmenting = false;
+ drawMask(maskURL);
+ });
+
+ async function getSegmentationMask(points) {
+ const modelID = modelSelection.value;
+ const modelURL = MODEL_BASEURL + MODELS[modelID].url;
+ const imageURL = currentImageURL;
+ const { maskURL } = await segmentPoints(
+ modelURL,
+ modelID,
+ imageURL,
+ points
+ );
+ return { maskURL };
+ }
+ async function setImageEmbeddings(imageURL) {
+ if (isEmbedding) {
+ return;
+ }
+ canvas.classList.remove("cursor-pointer");
+ canvas.classList.add("cursor-wait");
+ clearBtn.disabled = true;
+ const modelID = modelSelection.value;
+ const modelURL = MODEL_BASEURL + MODELS[modelID].url;
+ isEmbedding = true;
+ await segmentPoints(modelURL, modelID, imageURL);
+ canvas.classList.remove("cursor-wait");
+ canvas.classList.add("cursor-pointer");
+ clearBtn.disabled = false;
+ isEmbedding = false;
+ currentImageURL = imageURL;
+ }
+
+ function cleanImageCanvas() {
+ ctxCanvas.clearRect(0, 0, canvas.width, canvas.height);
+ ctxMask.clearRect(0, 0, canvas.width, canvas.height);
+ hasImage = false;
+ isEmbedding = false;
+ isSegmenting = false;
+ currentImageURL = "";
+ clearBtn.classList.add("invisible");
+ canvas.parentElement.style.height = "auto";
+ dropButtons.classList.remove("invisible");
+ }
+ function drawMask(maskURL) {
+ if (!maskURL) {
+ throw new Error("No mask URL provided");
+ }
+
+ const img = new Image();
+ img.crossOrigin = "anonymous";
+
+ img.onload = () => {
+ mask.width = canvas.width;
+ mask.height = canvas.height;
+ ctxMask.drawImage(canvas, 0, 0);
+ ctxMask.globalCompositeOperation = "source-atop";
+ ctxMask.fillStyle = "rgba(255, 0, 0, 0.6)";
+ ctxMask.fillRect(0, 0, canvas.width, canvas.height);
+ ctxMask.globalCompositeOperation = "destination-in";
+ ctxMask.drawImage(img, 0, 0);
+ };
+ img.src = maskURL;
+ }
+ function drawImageCanvas(imgURL) {
+ if (!imgURL) {
+ throw new Error("No image URL provided");
+ }
+
+ ctxCanvas.clearRect(0, 0, canvas.width, canvas.height);
+ ctxCanvas.clearRect(0, 0, canvas.width, canvas.height);
+
+ const img = new Image();
+ img.crossOrigin = "anonymous";
+
+ img.onload = () => {
+ canvas.width = img.width;
+ canvas.height = img.height;
+ ctxCanvas.drawImage(img, 0, 0);
+ canvas.parentElement.style.height = canvas.offsetHeight + "px";
+ hasImage = true;
+ clearBtn.classList.remove("invisible");
+ dropButtons.classList.add("invisible");
+ };
+ img.src = imgURL;
+ }
+
+ const observer = new ResizeObserver((entries) => {
+ for (let entry of entries) {
+ if (entry.target === canvas) {
+ canvas.parentElement.style.height = canvas.offsetHeight + "px";
+ }
+ }
+ });
+ observer.observe(canvas);
+ </script>
+ </head>
+ <body class="container max-w-4xl mx-auto p-4">
+ <main class="grid grid-cols-1 gap-8 relative">
+ <span class="absolute text-5xl -ml-[1em]">🕯️</span>
+ <div>
+ <h1 class="text-5xl font-bold">Candle Segment Anything</h1>
+ <h2 class="text-2xl font-bold">Rust/WASM Demo</h2>
+ <p class="max-w-lg">
+ Zero-shot image segmentation with
+ <a
+ href="https://segment-anything.com"
+ class="underline hover:text-blue-500 hover:no-underline"
+ target="_blank"
+ >Segment Anything Model (SAM)</a
+ >
+ and
+ <a
+ href="https://github.com/ChaoningZhang/MobileSAM"
+ class="underline hover:text-blue-500 hover:no-underline"
+ target="_blank"
+ >MobileSAM </a
+ >. It runs in the browser with a WASM runtime built with
+ <a
+ href="https://github.com/huggingface/candle/"
+ target="_blank"
+ class="underline hover:text-blue-500 hover:no-underline"
+ >Candle
+ </a>
+ </p>
+ </div>
+ <div>
+ <label for="model" class="font-medium">Models Options: </label>
+ <select
+ id="model"
+ class="border-2 border-gray-500 rounded-md font-light"
+ >
+ <option value="sam_mobile_tiny" selected>
+ Mobile SAM Tiny (40.6 MB)
+ </option>
+ <option value="sam_base">SAM Base (375 MB)</option>
+ </select>
+ </div>
+ <div>
+ <p class="text-xs italic max-w-lg">
+ <b>Note:</b>
+ The model's first run may take a few seconds as it loads and caches
+ the model in the browser, and then creates the image embeddings. Any
+ subsequent clicks on points will be significantly faster.
+ </p>
+ </div>
+ <div class="relative max-w-lg">
+ <div class="flex justify-between items-center">
+ <div class="px-2 rounded-md inline text-xs">
+ <span id="output-status" class="m-auto font-light"></span>
+ </div>
+ <button
+ id="clear-btn"
+ class="text-xs bg-white rounded-md disabled:opacity-50 flex gap-1 items-center invisible"
+ >
+ <svg
+ class=""
+ xmlns="http://www.w3.org/2000/svg"
+ viewBox="0 0 13 12"
+ height="1em"
+ >
+ <path
+ d="M1.6.7 12 11.1M12 .7 1.6 11.1"
+ stroke="#2E3036"
+ stroke-width="2"
+ />
+ </svg>
+ Clear image
+ </button>
+ </div>
+ <div
+ id="drop-area"
+ class="flex flex-col items-center justify-center border-2 border-gray-300 border-dashed rounded-xl relative p-20 w-full overflow-hidden"
+ >
+ <div
+ id="drop-buttons"
+ class="flex flex-col items-center justify-center space-y-1 text-center relative z-10"
+ >
+ <svg
+ width="25"
+ height="25"
+ viewBox="0 0 25 25"
+ fill="none"
+ xmlns="http://www.w3.org/2000/svg"
+ >
+ <path
+ d="M3.5 24.3a3 3 0 0 1-1.9-.8c-.5-.5-.8-1.2-.8-1.9V2.9c0-.7.3-1.3.8-1.9.6-.5 1.2-.7 2-.7h18.6c.7 0 1.3.2 1.9.7.5.6.7 1.2.7 2v18.6c0 .7-.2 1.4-.7 1.9a3 3 0 0 1-2 .8H3.6Zm0-2.7h18.7V2.9H3.5v18.7Zm2.7-2.7h13.3c.3 0 .5 0 .6-.3v-.7l-3.7-5a.6.6 0 0 0-.6-.2c-.2 0-.4 0-.5.3l-3.5 4.6-2.4-3.3a.6.6 0 0 0-.6-.3c-.2 0-.4.1-.5.3l-2.7 3.6c-.1.2-.2.4 0 .7.1.2.3.3.6.3Z"
+ fill="#000"
+ />
+ </svg>
+ <div class="flex text-sm text-gray-600">
+ <label
+ for="file-upload"
+ class="relative cursor-pointer bg-white rounded-md font-medium text-blue-950 hover:text-blue-700"
+ >
+ <span>Drag and drop your image here</span>
+ <span class="block text-xs">or</span>
+ <span class="block text-xs">Click to upload</span>
+ </label>
+ </div>
+ <input
+ id="file-upload"
+ name="file-upload"
+ type="file"
+ class="sr-only"
+ />
+ </div>
+ <canvas id="canvas" class="absolute w-full"></canvas>
+ <canvas
+ id="mask"
+ class="pointer-events-none absolute w-full"
+ ></canvas>
+ </div>
+ <div class="text-right py-2">
+ <button
+ id="share-btn"
+ class="bg-white rounded-md hover:outline outline-orange-200 disabled:opacity-50 invisible"
+ >
+ <img
+ src="https://huggingface.co/datasets/huggingface/badges/raw/main/share-to-community-sm.svg"
+ />
+ </button>
+ </div>
+ </div>
+ <div>
+ <div
+ class="flex gap-3 items-center overflow-x-scroll"
+ id="image-select"
+ >
+ <h3 class="font-medium">Examples:</h3>
+
+ <img
+ src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/candle/examples/sf.jpg"
+ class="cursor-pointer w-24 h-24 object-cover"
+ />
+ <img
+ src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/candle/examples/bike.jpeg"
+ class="cursor-pointer w-24 h-24 object-cover"
+ />
+ <img
+ src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/candle/examples/000000000077.jpg"
+ class="cursor-pointer w-24 h-24 object-cover"
+ />
+ </div>
+ </div>
+ </main>
+ </body>
+</html>
diff --git a/candle-wasm-examples/segment-anything/samWorker.js b/candle-wasm-examples/segment-anything/samWorker.js
new file mode 100644
index 00000000..b90498de
--- /dev/null
+++ b/candle-wasm-examples/segment-anything/samWorker.js
@@ -0,0 +1,156 @@
+//load the candle SAM Model wasm module
+import init, { Model } from "./build/m.js";
+
+async function fetchArrayBuffer(url, cacheModel = true) {
+ if (!cacheModel)
+ return new Uint8Array(await (await fetch(url)).arrayBuffer());
+ const cacheName = "sam-candle-cache";
+ const cache = await caches.open(cacheName);
+ const cachedResponse = await cache.match(url);
+ if (cachedResponse) {
+ const data = await cachedResponse.arrayBuffer();
+ return new Uint8Array(data);
+ }
+ const res = await fetch(url, { cache: "force-cache" });
+ cache.put(url, res.clone());
+ return new Uint8Array(await res.arrayBuffer());
+}
+class SAMModel {
+ static instance = {};
+ // keep current image embeddings state
+ static imageArrayHash = {};
+ // Add a new property to hold the current modelID
+ static currentModelID = null;
+
+ static async getInstance(modelURL, modelID) {
+ if (!this.instance[modelID]) {
+ await init();
+
+ self.postMessage({
+ status: "loading",
+ message: `Loading Model ${modelID}`,
+ });
+ const weightsArrayU8 = await fetchArrayBuffer(modelURL);
+ this.instance[modelID] = new Model(
+ weightsArrayU8,
+ /tiny|mobile/.test(modelID)
+ );
+ } else {
+ self.postMessage({ status: "loading", message: "Model Already Loaded" });
+ }
+ // Set the current modelID to the modelID that was passed in
+ this.currentModelID = modelID;
+ return this.instance[modelID];
+ }
+
+ // Remove the modelID parameter from setImageEmbeddings
+ static setImageEmbeddings(imageArrayU8) {
+ // check if image embeddings are already set for this image and model
+ const imageArrayHash = this.getSimpleHash(imageArrayU8);
+ if (
+ this.imageArrayHash[this.currentModelID] === imageArrayHash &&
+ this.instance[this.currentModelID]
+ ) {
+ self.postMessage({
+ status: "embedding",
+ message: "Embeddings Already Set",
+ });
+ return;
+ }
+ this.imageArrayHash[this.currentModelID] = imageArrayHash;
+ this.instance[this.currentModelID].set_image_embeddings(imageArrayU8);
+ self.postMessage({ status: "embedding", message: "Embeddings Set" });
+ }
+
+ static getSimpleHash(imageArrayU8) {
+ // get simple hash of imageArrayU8
+ let imageArrayHash = 0;
+ for (let i = 0; i < imageArrayU8.length; i += 100) {
+ imageArrayHash ^= imageArrayU8[i];
+ }
+ return imageArrayHash.toString(16);
+ }
+}
+
+async function createImageCanvas(
+ { mask_shape, mask_data }, // mask
+ { original_width, original_height, width, height } // original image
+) {
+ const [_, __, shape_width, shape_height] = mask_shape;
+ const maskCanvas = new OffscreenCanvas(shape_width, shape_height); // canvas for mask
+ const maskCtx = maskCanvas.getContext("2d");
+ const canvas = new OffscreenCanvas(original_width, original_height); // canvas for creating mask with original image size
+ const ctx = canvas.getContext("2d");
+
+ const imageData = maskCtx.createImageData(
+ maskCanvas.width,
+ maskCanvas.height
+ );
+ const data = imageData.data;
+
+ for (let p = 0; p < data.length; p += 4) {
+ data[p] = 0;
+ data[p + 1] = 0;
+ data[p + 2] = 0;
+ data[p + 3] = mask_data[p / 4] * 255;
+ }
+ maskCtx.putImageData(imageData, 0, 0);
+
+ let sx, sy;
+ if (original_height < original_width) {
+ sy = original_height / original_width;
+ sx = 1;
+ } else {
+ sy = 1;
+ sx = original_width / original_height;
+ }
+ ctx.drawImage(
+ maskCanvas,
+ 0,
+ 0,
+ maskCanvas.width * sx,
+ maskCanvas.height * sy,
+ 0,
+ 0,
+ original_width,
+ original_height
+ );
+
+ const blob = await canvas.convertToBlob();
+ return URL.createObjectURL(blob);
+}
+
+self.addEventListener("message", async (event) => {
+ const { modelURL, modelID, imageURL, points } = event.data;
+ try {
+ self.postMessage({ status: "loading", message: "Starting SAM" });
+ const sam = await SAMModel.getInstance(modelURL, modelID);
+
+ self.postMessage({ status: "loading", message: "Loading Image" });
+ const imageArrayU8 = await fetchArrayBuffer(imageURL, false);
+
+ self.postMessage({ status: "embedding", message: "Creating Embeddings" });
+ SAMModel.setImageEmbeddings(imageArrayU8);
+ if (!points) {
+ // no points only do the embeddings
+ self.postMessage({
+ status: "complete-embedding",
+ message: "Embeddings Complete",
+ });
+ return;
+ }
+
+ self.postMessage({ status: "segmenting", message: "Segmenting" });
+ const result = sam.mask_for_point(points.x, points.y);
+ const { mask, image } = JSON.parse(result);
+ const maskDataURL = await createImageCanvas(mask, image);
+ // Send the segment back to the main thread as JSON
+ self.postMessage({
+ status: "complete",
+ message: "Segmentation Complete",
+ output: { maskURL: maskDataURL },
+ });
+ } catch (e) {
+ self.postMessage({ error: e });
+ }
+});
diff --git a/candle-wasm-examples/segment-anything/src/bin/m.rs b/candle-wasm-examples/segment-anything/src/bin/m.rs
index b53f5b9b..949c18a0 100644
--- a/candle-wasm-examples/segment-anything/src/bin/m.rs
+++ b/candle-wasm-examples/segment-anything/src/bin/m.rs
@@ -98,7 +98,7 @@ impl Model {
Some((x, y)),
false,
)?;
- let iou = iou_predictions.to_vec1::<f32>()?[0];
+ let iou = iou_predictions.flatten(0, 1)?.to_vec1::<f32>()?[0];
let mask_shape = mask.dims().to_vec();
let mask_data = mask.ge(0f32)?.flatten_all()?.to_vec1::<u8>()?;
let mask = Mask {
@@ -106,7 +106,13 @@ impl Model {
mask_shape,
mask_data,
};
- let json = serde_json::to_string(&mask)?;
+ let image = Image {
+ original_width: embeddings.original_width,
+ original_height: embeddings.original_height,
+ width: embeddings.width,
+ height: embeddings.height,
+ };
+ let json = serde_json::to_string(&MaskImage { mask, image })?;
Ok(json)
}
}
@@ -117,6 +123,18 @@ struct Mask {
mask_shape: Vec<usize>,
mask_data: Vec<u8>,
}
+#[derive(serde::Serialize, serde::Deserialize)]
+struct Image {
+ original_width: u32,
+ original_height: u32,
+ width: u32,
+ height: u32,
+}
+#[derive(serde::Serialize, serde::Deserialize)]
+struct MaskImage {
+ mask: Mask,
+ image: Image,
+}
fn main() {
console_error_panic_hook::set_once();