diff options
author | Radamés Ajna <radamajna@gmail.com> | 2023-09-14 22:31:58 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-09-15 06:31:58 +0100 |
commit | 39157346cb96d7610ddb6cdc686aed32d12bba3d (patch) | |
tree | a2de0d69d5f5a59e4062750d93898f37e65cb44f /candle-wasm-examples/segment-anything/src | |
parent | 5cefbba75777547f97abd92affcf9ef10ac36163 (diff) | |
download | candle-39157346cb96d7610ddb6cdc686aed32d12bba3d.tar.gz candle-39157346cb96d7610ddb6cdc686aed32d12bba3d.tar.bz2 candle-39157346cb96d7610ddb6cdc686aed32d12bba3d.zip |
Add SAM UI Demo (#854)
* fix tensor flattening
* send image data back
* sam ui worker example
* SAM example
* resize container
* no need for this
Diffstat (limited to 'candle-wasm-examples/segment-anything/src')
-rw-r--r-- | candle-wasm-examples/segment-anything/src/bin/m.rs | 22 |
1 files changed, 20 insertions, 2 deletions
diff --git a/candle-wasm-examples/segment-anything/src/bin/m.rs b/candle-wasm-examples/segment-anything/src/bin/m.rs index b53f5b9b..949c18a0 100644 --- a/candle-wasm-examples/segment-anything/src/bin/m.rs +++ b/candle-wasm-examples/segment-anything/src/bin/m.rs @@ -98,7 +98,7 @@ impl Model { Some((x, y)), false, )?; - let iou = iou_predictions.to_vec1::<f32>()?[0]; + let iou = iou_predictions.flatten(0, 1)?.to_vec1::<f32>()?[0]; let mask_shape = mask.dims().to_vec(); let mask_data = mask.ge(0f32)?.flatten_all()?.to_vec1::<u8>()?; let mask = Mask { @@ -106,7 +106,13 @@ impl Model { mask_shape, mask_data, }; - let json = serde_json::to_string(&mask)?; + let image = Image { + original_width: embeddings.original_width, + original_height: embeddings.original_height, + width: embeddings.width, + height: embeddings.height, + }; + let json = serde_json::to_string(&MaskImage { mask, image })?; Ok(json) } } @@ -117,6 +123,18 @@ struct Mask { mask_shape: Vec<usize>, mask_data: Vec<u8>, } +#[derive(serde::Serialize, serde::Deserialize)] +struct Image { + original_width: u32, + original_height: u32, + width: u32, + height: u32, +} +#[derive(serde::Serialize, serde::Deserialize)] +struct MaskImage { + mask: Mask, + image: Image, +} fn main() { console_error_panic_hook::set_once(); |