summaryrefslogtreecommitdiff
path: root/candle-examples/examples/segment-anything/main.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-examples/examples/segment-anything/main.rs')
-rw-r--r--candle-examples/examples/segment-anything/main.rs16
1 files changed, 14 insertions, 2 deletions
diff --git a/candle-examples/examples/segment-anything/main.rs b/candle-examples/examples/segment-anything/main.rs
index a749ba2a..4627248c 100644
--- a/candle-examples/examples/segment-anything/main.rs
+++ b/candle-examples/examples/segment-anything/main.rs
@@ -188,13 +188,25 @@ pub fn main() -> anyhow::Result<()> {
if args.generate_masks {
// Default options similar to the Python version.
- sam.generate_masks(
+ let bboxes = sam.generate_masks(
&image,
/* points_per_side */ 32,
/* crop_n_layer */ 0,
/* crop_overlap_ratio */ 512. / 1500.,
/* crop_n_points_downscale_factor */ 1,
- )?
+ )?;
+ for (idx, bbox) in bboxes.iter().enumerate() {
+ println!("{bbox:?}");
+ let mask = (&bbox.data.to_dtype(DType::U8)? * 255.)?;
+ let (h, w) = mask.dims2()?;
+ let mask = mask.broadcast_as((3, h, w))?;
+ candle_examples::save_image_resize(
+ &mask,
+ format!("sam_mask{idx}.png"),
+ initial_h,
+ initial_w,
+ )?;
+ }
} else {
let point = Some((args.point_x, args.point_y));
let (mask, iou_predictions) = sam.forward(&image, point, false)?;