summaryrefslogtreecommitdiff
path: root/candle-examples/examples/segment-anything/main.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-examples/examples/segment-anything/main.rs')
-rw-r--r--candle-examples/examples/segment-anything/main.rs62
1 files changed, 42 insertions, 20 deletions
diff --git a/candle-examples/examples/segment-anything/main.rs b/candle-examples/examples/segment-anything/main.rs
index c53c1010..0f0c0482 100644
--- a/candle-examples/examples/segment-anything/main.rs
+++ b/candle-examples/examples/segment-anything/main.rs
@@ -1,6 +1,5 @@
//! SAM: Segment Anything Model
//! https://github.com/facebookresearch/segment-anything
-#![allow(unused)]
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
@@ -14,7 +13,7 @@ pub mod model_prompt_encoder;
pub mod model_sam;
pub mod model_transformer;
-use candle::{DType, IndexOp, Result, Tensor, D};
+use candle::{DType, Result, Tensor};
use candle_nn::{Linear, Module, VarBuilder};
use clap::Parser;
@@ -101,6 +100,15 @@ struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
+
+ #[arg(long)]
+ generate_masks: bool,
+
+ #[arg(long)]
+ point_x: Option<f64>,
+
+ #[arg(long)]
+ point_y: Option<f64>,
}
pub fn main() -> anyhow::Result<()> {
@@ -108,7 +116,7 @@ pub fn main() -> anyhow::Result<()> {
let device = candle_examples::device(args.cpu)?;
- let image = if args.image.ends_with(".safetensors") {
+ let (image, initial_h, initial_w) = if args.image.ends_with(".safetensors") {
let mut tensors = candle::safetensors::load(&args.image, &device)?;
let image = match tensors.remove("image") {
Some(image) => image,
@@ -119,13 +127,16 @@ pub fn main() -> anyhow::Result<()> {
tensors.into_values().next().unwrap()
}
};
- if image.rank() == 4 {
+ let image = if image.rank() == 4 {
image.get(0)?
} else {
image
- }
+ };
+ let (_c, h, w) = image.dims3()?;
+ (image, h, w)
} else {
- candle_examples::load_image(args.image, Some(model_sam::IMAGE_SIZE))?.to_device(&device)?
+ let (image, h, w) = candle_examples::load_image(args.image, Some(model_sam::IMAGE_SIZE))?;
+ (image.to_device(&device)?, h, w)
};
println!("loaded image {image:?}");
@@ -142,19 +153,30 @@ pub fn main() -> anyhow::Result<()> {
let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
let sam = model_sam::Sam::new(768, 12, 12, &[2, 5, 8, 11], vb)?; // sam_vit_b
- let (mask, iou_predictions) = sam.forward(&image, false)?;
- println!("mask:\n{mask}");
- println!("iou_predictions: {iou_predictions:?}");
-
- // Save the mask as an image.
- let mask = mask.ge(&mask.zeros_like()?)?;
- let mask = (mask * 255.)?.squeeze(0)?;
- let (_one, h, w) = mask.dims3()?;
- let mask = mask.expand((3, h, w))?;
- candle_examples::save_image(&mask, "sam_mask.png")?;
-
- let image = sam.preprocess(&image)?;
- let image = sam.unpreprocess(&image)?.to_dtype(DType::U8)?;
- candle_examples::save_image(&image, "sam_input_scaled.png")?;
+ if args.generate_masks {
+ // Default options similar to the Python version.
+ sam.generate_masks(
+ &image,
+ /* points_per_side */ 32,
+ /* crop_n_layer */ 0,
+ /* crop_overlap_ratio */ 512. / 1500.,
+ /* crop_n_points_downscale_factor */ 1,
+ )?
+ } else {
+ let point = args.point_x.zip(args.point_y);
+ let (mask, iou_predictions) = sam.forward(&image, point, false)?;
+ println!("mask:\n{mask}");
+ println!("iou_predictions: {iou_predictions:?}");
+
+ // Save the mask as an image.
+ let mask = (mask.ge(&mask.zeros_like()?)? * 255.)?;
+ let (_one, h, w) = mask.dims3()?;
+ let mask = mask.expand((3, h, w))?;
+ candle_examples::save_image_resize(&mask, "sam_mask.png", initial_h, initial_w)?;
+
+ let image = sam.preprocess(&image)?;
+ let image = sam.unpreprocess(&image)?.to_dtype(DType::U8)?;
+ candle_examples::save_image(&image, "sam_input_scaled.png")?;
+ }
Ok(())
}