diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-09-08 12:26:56 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-09-08 12:26:56 +0100 |
commit | 28c87f6a34e594aca5f558bceebc4c0a9c95911a (patch) | |
tree | 11d702a507de898a7e734aa22349657d04931fb4 /candle-examples/examples/segment-anything/main.rs | |
parent | c1453f00b11c9dd12c5aa81fb4355ce47d22d477 (diff) | |
download | candle-28c87f6a34e594aca5f558bceebc4c0a9c95911a.tar.gz candle-28c87f6a34e594aca5f558bceebc4c0a9c95911a.tar.bz2 candle-28c87f6a34e594aca5f558bceebc4c0a9c95911a.zip |
Automatic mask generator + point base mask (#773)
* Add more to the automatic mask generator.
* Add the target point.
* Fix.
* Remove the allow-unused.
* Mask post-processing.
Diffstat (limited to 'candle-examples/examples/segment-anything/main.rs')
-rw-r--r-- | candle-examples/examples/segment-anything/main.rs | 62 |
1 files changed, 42 insertions, 20 deletions
diff --git a/candle-examples/examples/segment-anything/main.rs b/candle-examples/examples/segment-anything/main.rs index c53c1010..0f0c0482 100644 --- a/candle-examples/examples/segment-anything/main.rs +++ b/candle-examples/examples/segment-anything/main.rs @@ -1,6 +1,5 @@ //! SAM: Segment Anything Model //! https://github.com/facebookresearch/segment-anything -#![allow(unused)] #[cfg(feature = "mkl")] extern crate intel_mkl_src; @@ -14,7 +13,7 @@ pub mod model_prompt_encoder; pub mod model_sam; pub mod model_transformer; -use candle::{DType, IndexOp, Result, Tensor, D}; +use candle::{DType, Result, Tensor}; use candle_nn::{Linear, Module, VarBuilder}; use clap::Parser; @@ -101,6 +100,15 @@ struct Args { /// Run on CPU rather than on GPU. #[arg(long)] cpu: bool, + + #[arg(long)] + generate_masks: bool, + + #[arg(long)] + point_x: Option<f64>, + + #[arg(long)] + point_y: Option<f64>, } pub fn main() -> anyhow::Result<()> { @@ -108,7 +116,7 @@ pub fn main() -> anyhow::Result<()> { let device = candle_examples::device(args.cpu)?; - let image = if args.image.ends_with(".safetensors") { + let (image, initial_h, initial_w) = if args.image.ends_with(".safetensors") { let mut tensors = candle::safetensors::load(&args.image, &device)?; let image = match tensors.remove("image") { Some(image) => image, @@ -119,13 +127,16 @@ pub fn main() -> anyhow::Result<()> { tensors.into_values().next().unwrap() } }; - if image.rank() == 4 { + let image = if image.rank() == 4 { image.get(0)? } else { image - } + }; + let (_c, h, w) = image.dims3()?; + (image, h, w) } else { - candle_examples::load_image(args.image, Some(model_sam::IMAGE_SIZE))?.to_device(&device)? + let (image, h, w) = candle_examples::load_image(args.image, Some(model_sam::IMAGE_SIZE))?; + (image.to_device(&device)?, h, w) }; println!("loaded image {image:?}"); @@ -142,19 +153,30 @@ pub fn main() -> anyhow::Result<()> { let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device); let sam = model_sam::Sam::new(768, 12, 12, &[2, 5, 8, 11], vb)?; // sam_vit_b - let (mask, iou_predictions) = sam.forward(&image, false)?; - println!("mask:\n{mask}"); - println!("iou_predictions: {iou_predictions:?}"); - - // Save the mask as an image. - let mask = mask.ge(&mask.zeros_like()?)?; - let mask = (mask * 255.)?.squeeze(0)?; - let (_one, h, w) = mask.dims3()?; - let mask = mask.expand((3, h, w))?; - candle_examples::save_image(&mask, "sam_mask.png")?; - - let image = sam.preprocess(&image)?; - let image = sam.unpreprocess(&image)?.to_dtype(DType::U8)?; - candle_examples::save_image(&image, "sam_input_scaled.png")?; + if args.generate_masks { + // Default options similar to the Python version. + sam.generate_masks( + &image, + /* points_per_side */ 32, + /* crop_n_layer */ 0, + /* crop_overlap_ratio */ 512. / 1500., + /* crop_n_points_downscale_factor */ 1, + )? + } else { + let point = args.point_x.zip(args.point_y); + let (mask, iou_predictions) = sam.forward(&image, point, false)?; + println!("mask:\n{mask}"); + println!("iou_predictions: {iou_predictions:?}"); + + // Save the mask as an image. + let mask = (mask.ge(&mask.zeros_like()?)? * 255.)?; + let (_one, h, w) = mask.dims3()?; + let mask = mask.expand((3, h, w))?; + candle_examples::save_image_resize(&mask, "sam_mask.png", initial_h, initial_w)?; + + let image = sam.preprocess(&image)?; + let image = sam.unpreprocess(&image)?.to_dtype(DType::U8)?; + candle_examples::save_image(&image, "sam_input_scaled.png")?; + } Ok(()) } |