summaryrefslogtreecommitdiff
path: root/candle-wasm-examples/segment-anything
diff options
context:
space:
mode:
authorlichin-lin <vic20087cjimlin@gmail.com>2023-10-05 22:14:47 +0100
committerGitHub <noreply@github.com>2023-10-05 22:14:47 +0100
commit47c25a567bd14ab3e830b5d768cd80f33ed9545b (patch)
tree2a8a81db9d027e2fc9020958c27dc20ced161c0e /candle-wasm-examples/segment-anything
parent7f7d95e2c36ecda7349c51ba06ea4ba6cc7f2482 (diff)
downloadcandle-47c25a567bd14ab3e830b5d768cd80f33ed9545b.tar.gz
candle-47c25a567bd14ab3e830b5d768cd80f33ed9545b.tar.bz2
candle-47c25a567bd14ab3e830b5d768cd80f33ed9545b.zip
feat: [SAM] able to download the result as png (#1035)
* feat: able to download the result as png * feat: update function and wording
Diffstat (limited to 'candle-wasm-examples/segment-anything')
-rw-r--r--candle-wasm-examples/segment-anything/lib-example.html60
1 files changed, 60 insertions, 0 deletions
diff --git a/candle-wasm-examples/segment-anything/lib-example.html b/candle-wasm-examples/segment-anything/lib-example.html
index adcd02ab..f6b5931f 100644
--- a/candle-wasm-examples/segment-anything/lib-example.html
+++ b/candle-wasm-examples/segment-anything/lib-example.html
@@ -73,9 +73,12 @@
statusOutput.innerText = event.data.message;
}
+ let copyMaskURL = null;
+ let copyImageURL = null;
const clearBtn = document.querySelector("#clear-btn");
const maskBtn = document.querySelector("#mask-btn");
const undoBtn = document.querySelector("#undo-btn");
+ const downloadBtn = document.querySelector("#download-btn");
const canvas = document.querySelector("#canvas");
const mask = document.querySelector("#mask");
const ctxCanvas = canvas.getContext("2d");
@@ -93,6 +96,7 @@
if (target.files.length > 0) {
const href = URL.createObjectURL(target.files[0]);
clearImageCanvas();
+ copyImageURL = href;
drawImageCanvas(href);
setImageEmbeddings(href);
togglePointMode(false);
@@ -119,11 +123,13 @@
if (files.length > 0) {
const href = URL.createObjectURL(files[0]);
clearImageCanvas();
+ copyImageURL = href;
drawImageCanvas(href);
setImageEmbeddings(href);
togglePointMode(false);
} else if (url) {
clearImageCanvas();
+ copyImageURL = url;
drawImageCanvas(url);
setImageEmbeddings(url);
togglePointMode(false);
@@ -145,6 +151,7 @@
if (target.nodeName === "IMG") {
const href = target.src;
clearImageCanvas();
+ copyImageURL = href;
drawImageCanvas(href);
setImageEmbeddings(href);
}
@@ -163,6 +170,46 @@
undoBtn.addEventListener("click", () => {
undoPoint();
});
+ // add event to download btn
+ downloadBtn.addEventListener("click", async () => {
+ // Function to load image blobs as Image elements asynchronously
+ const loadImageAsync = (imageURL) => {
+ return new Promise((resolve) => {
+ const img = new Image();
+ img.onload = () => {
+ resolve(img);
+ };
+ img.crossOrigin = "anonymous";
+ img.src = imageURL;
+ });
+ };
+ const originalImage = await loadImageAsync(copyImageURL);
+ const maskImage = await loadImageAsync(copyMaskURL);
+
+ // create main a board to draw
+ const canvas = document.createElement("canvas");
+ const ctx = canvas.getContext("2d");
+ canvas.width = originalImage.width;
+ canvas.height = originalImage.height;
+
+ // Perform the mask operation
+ ctx.drawImage(maskImage, 0, 0);
+ ctx.globalCompositeOperation = "source-in";
+ ctx.drawImage(originalImage, 0, 0);
+
+ // to blob
+ const blobPromise = new Promise((resolve) => {
+ canvas.toBlob(resolve);
+ });
+ const blob = await blobPromise;
+ const resultURL = URL.createObjectURL(blob);
+
+ // download
+ const link = document.createElement("a");
+ link.href = resultURL;
+ link.download = "cutout.png";
+ link.click();
+ });
//add click event to canvas
canvas.addEventListener("click", async (event) => {
if (!hasImage || isEmbedding || isSegmenting) {
@@ -185,14 +232,17 @@
pointArr = [...pointArr, [x, y, !backgroundMode]];
}
undoBtn.disabled = false;
+ downloadBtn.disabled = false;
if (pointArr.length == 0) {
ctxMask.clearRect(0, 0, canvas.width, canvas.height);
undoBtn.disabled = true;
+ downloadBtn.disabled = true;
return;
}
isSegmenting = true;
const { maskURL } = await getSegmentationMask(pointArr);
isSegmenting = false;
+ copyMaskURL = maskURL;
drawMask(maskURL, pointArr);
});
@@ -212,6 +262,7 @@
isSegmenting = true;
const { maskURL } = await getSegmentationMask(pointArr);
isSegmenting = false;
+ copyMaskURL = maskURL;
drawMask(maskURL, pointArr);
}
function togglePointMode(mode) {
@@ -490,6 +541,15 @@
<img
src="https://huggingface.co/datasets/huggingface/badges/raw/main/share-to-community-sm.svg" />
</button>
+
+ <button
+ id="download-btn"
+ title="Copy result (.png)"
+ disabled
+ class="p-1 px-2 text-xs font-medium bg-white rounded-2xl outline outline-gray-200 hover:outline-orange-200 disabled:opacity-50"
+ >
+ Download Cut-Out
+ </button>
</div>
</div>
<div>