summaryrefslogtreecommitdiff
path: root/candle-wasm-examples/segment-anything/src
diff options
context:
space:
mode:
authorRadamés Ajna <radamajna@gmail.com>2023-09-14 22:31:58 -0700
committerGitHub <noreply@github.com>2023-09-15 06:31:58 +0100
commit39157346cb96d7610ddb6cdc686aed32d12bba3d (patch)
treea2de0d69d5f5a59e4062750d93898f37e65cb44f /candle-wasm-examples/segment-anything/src
parent5cefbba75777547f97abd92affcf9ef10ac36163 (diff)
downloadcandle-39157346cb96d7610ddb6cdc686aed32d12bba3d.tar.gz
candle-39157346cb96d7610ddb6cdc686aed32d12bba3d.tar.bz2
candle-39157346cb96d7610ddb6cdc686aed32d12bba3d.zip
Add SAM UI Demo (#854)
* fix tensor flattening * send image data back * sam ui worker example * SAM example * resize container * no need for this
Diffstat (limited to 'candle-wasm-examples/segment-anything/src')
-rw-r--r--candle-wasm-examples/segment-anything/src/bin/m.rs22
1 files changed, 20 insertions, 2 deletions
diff --git a/candle-wasm-examples/segment-anything/src/bin/m.rs b/candle-wasm-examples/segment-anything/src/bin/m.rs
index b53f5b9b..949c18a0 100644
--- a/candle-wasm-examples/segment-anything/src/bin/m.rs
+++ b/candle-wasm-examples/segment-anything/src/bin/m.rs
@@ -98,7 +98,7 @@ impl Model {
Some((x, y)),
false,
)?;
- let iou = iou_predictions.to_vec1::<f32>()?[0];
+ let iou = iou_predictions.flatten(0, 1)?.to_vec1::<f32>()?[0];
let mask_shape = mask.dims().to_vec();
let mask_data = mask.ge(0f32)?.flatten_all()?.to_vec1::<u8>()?;
let mask = Mask {
@@ -106,7 +106,13 @@ impl Model {
mask_shape,
mask_data,
};
- let json = serde_json::to_string(&mask)?;
+ let image = Image {
+ original_width: embeddings.original_width,
+ original_height: embeddings.original_height,
+ width: embeddings.width,
+ height: embeddings.height,
+ };
+ let json = serde_json::to_string(&MaskImage { mask, image })?;
Ok(json)
}
}
@@ -117,6 +123,18 @@ struct Mask {
mask_shape: Vec<usize>,
mask_data: Vec<u8>,
}
+#[derive(serde::Serialize, serde::Deserialize)]
+struct Image {
+ original_width: u32,
+ original_height: u32,
+ width: u32,
+ height: u32,
+}
+#[derive(serde::Serialize, serde::Deserialize)]
+struct MaskImage {
+ mask: Mask,
+ image: Image,
+}
fn main() {
console_error_panic_hook::set_once();