diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-09-12 14:35:55 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-09-12 14:35:55 +0100 |
commit | 42da17694a4214a3e39e0d64afc22635ce83f557 (patch) | |
tree | d996a2dd9cbe3eb4aef17dc26fccddc957326553 | |
parent | 25aacda28eff79d14be09d636275b0351d81069d (diff) | |
download | candle-42da17694a4214a3e39e0d64afc22635ce83f557.tar.gz candle-42da17694a4214a3e39e0d64afc22635ce83f557.tar.bz2 candle-42da17694a4214a3e39e0d64afc22635ce83f557.zip |
Segment Anything readme (#827)
* Add a readme for the segment-anything model.
* Add the original image.
* Clean-up the segment anything cli example.
* Also print the mask id in the outputs.
-rw-r--r-- | candle-examples/examples/segment-anything/README.md | 40 | ||||
-rw-r--r-- | candle-examples/examples/segment-anything/assets/sam_merged.jpg | bin | 0 -> 160984 bytes | |||
-rw-r--r-- | candle-examples/examples/segment-anything/main.rs | 110 | ||||
-rw-r--r-- | candle-examples/examples/yolo-v8/assets/peoples.pp.jpg | bin | 81845 -> 0 bytes |
4 files changed, 82 insertions, 68 deletions
diff --git a/candle-examples/examples/segment-anything/README.md b/candle-examples/examples/segment-anything/README.md new file mode 100644 index 00000000..3c5b034f --- /dev/null +++ b/candle-examples/examples/segment-anything/README.md @@ -0,0 +1,40 @@ +# candle-segment-anything: Segment-Anything Model + +This example is based on Meta AI [Segment-Anything +Model](https://github.com/facebookresearch/segment-anything). This model +provides a robust and fast image segmentation pipeline that can be tweaked via +some prompting (requesting some points to be in the target mask, requesting some +points to be part of the background so _not_ in the target mask, specifying some +bounding box). + +The default backbone can be replaced by the smaller and faster TinyViT model +based on [MobileSAM](https://github.com/ChaoningZhang/MobileSAM). + +## Running an example + +```bash +cargo run --example segment-anything --release -- \ + --image candle-examples/examples/yolo-v8/assets/bike.jpg \ + --use-tiny \ + --point-x 0.4 \ + --point-y 0.3 +``` + +Running this command generates a `sam_merged.jpg` file containing the original +image with a blue overlay of the selected mask. The red dot represents the prompt +specified by `--point-x 0.4 --point-y 0.3`; this prompt is assumed to be part +of the target mask. + +The values used for `--point-x` and `--point-y` should be between 0 and 1 and +are proportional to the image dimension, i.e. use 0.5 for the image center. + + + + + +### Command-line flags +- `--use-tiny`: use the TinyViT based MobileSAM backbone rather than the default + one. +- `--point-x`, `--point-y`: specifies the location of the target point. +- `--threshold`: sets the threshold value to be part of the mask, a negative + value results in a larger mask and can be specified via `--threshold=-1.2`. 
diff --git a/candle-examples/examples/segment-anything/assets/sam_merged.jpg b/candle-examples/examples/segment-anything/assets/sam_merged.jpg Binary files differ new file mode 100644 index 00000000..a5f64e5e --- /dev/null +++ b/candle-examples/examples/segment-anything/assets/sam_merged.jpg diff --git a/candle-examples/examples/segment-anything/main.rs b/candle-examples/examples/segment-anything/main.rs index 21ba0415..3d9898b6 100644 --- a/candle-examples/examples/segment-anything/main.rs +++ b/candle-examples/examples/segment-anything/main.rs @@ -27,12 +27,19 @@ struct Args { #[arg(long)] generate_masks: bool, + /// The target point x coordinate, between 0 and 1 (0.5 is at the middle of the image). #[arg(long, default_value_t = 0.5)] point_x: f64, + /// The target point y coordinate, between 0 and 1 (0.5 is at the middle of the image). #[arg(long, default_value_t = 0.5)] point_y: f64, + /// The detection threshold for the mask, 0 is the default value, negative values mean a larger + /// mask, positive makes the mask more selective. + #[arg(long, default_value_t = 0.)] + threshold: f32, + /// Enable tracing (generates a trace-timestamp.json file). #[arg(long)] tracing: bool, @@ -57,28 +64,9 @@ pub fn main() -> anyhow::Result<()> { let device = candle_examples::device(args.cpu)?; - let (image, initial_h, initial_w) = if args.image.ends_with(".safetensors") { - let mut tensors = candle::safetensors::load(&args.image, &device)?; - let image = match tensors.remove("image") { - Some(image) => image, - None => { - if tensors.len() != 1 { - anyhow::bail!("multiple tensors in '{}'", args.image) - } - tensors.into_values().next().unwrap() - } - }; - let image = if image.rank() == 4 { - image.get(0)? 
- } else { - image - }; - let (_c, h, w) = image.dims3()?; - (image, h, w) - } else { - let (image, h, w) = candle_examples::load_image(&args.image, Some(sam::IMAGE_SIZE))?; - (image.to_device(&device)?, h, w) - }; + let (image, initial_h, initial_w) = + candle_examples::load_image(&args.image, Some(sam::IMAGE_SIZE))?; + let image = image.to_device(&device)?; println!("loaded image {image:?}"); let model = match args.model { @@ -113,7 +101,7 @@ pub fn main() -> anyhow::Result<()> { /* crop_n_points_downscale_factor */ 1, )?; for (idx, bbox) in bboxes.iter().enumerate() { - println!("{bbox:?}"); + println!("{idx} {bbox:?}"); let mask = (&bbox.data.to_dtype(DType::U8)? * 255.)?; let (h, w) = mask.dims2()?; let mask = mask.broadcast_as((3, h, w))?; @@ -135,56 +123,42 @@ pub fn main() -> anyhow::Result<()> { println!("mask:\n{mask}"); println!("iou_predictions: {iou_predictions:?}"); - // Save the mask as an image. - let mask = (mask.ge(0f32)? * 255.)?; + let mask = (mask.ge(args.threshold)? * 255.)?; let (_one, h, w) = mask.dims3()?; let mask = mask.expand((3, h, w))?; - candle_examples::save_image_resize(&mask, "sam_mask.png", initial_h, initial_w)?; - - if !args.image.ends_with(".safetensors") { - let mut img = image::io::Reader::open(&args.image)? 
- .decode() - .map_err(candle::Error::wrap)?; - let mask_pixels = mask.permute((1, 2, 0))?.flatten_all()?.to_vec1::<u8>()?; - let mask_img: image::ImageBuffer<image::Rgb<u8>, Vec<u8>> = - match image::ImageBuffer::from_raw(w as u32, h as u32, mask_pixels) { - Some(image) => image, - None => anyhow::bail!("error saving merged image"), - }; - let mask_img = image::DynamicImage::from(mask_img).resize_to_fill( - img.width(), - img.height(), - image::imageops::FilterType::CatmullRom, - ); - for x in 0..img.width() { - for y in 0..img.height() { - let mask_p = imageproc::drawing::Canvas::get_pixel(&mask_img, x, y); - if mask_p.0[0] > 100 { - let mut img_p = imageproc::drawing::Canvas::get_pixel(&img, x, y); - img_p.0[2] = 255 - (255 - img_p.0[2]) / 2; - img_p.0[1] /= 2; - img_p.0[0] /= 2; - imageproc::drawing::Canvas::draw_pixel(&mut img, x, y, img_p) - } + + let mut img = image::io::Reader::open(&args.image)? + .decode() + .map_err(candle::Error::wrap)?; + let mask_pixels = mask.permute((1, 2, 0))?.flatten_all()?.to_vec1::<u8>()?; + let mask_img: image::ImageBuffer<image::Rgb<u8>, Vec<u8>> = + match image::ImageBuffer::from_raw(w as u32, h as u32, mask_pixels) { + Some(image) => image, + None => anyhow::bail!("error saving merged image"), + }; + let mask_img = image::DynamicImage::from(mask_img).resize_to_fill( + img.width(), + img.height(), + image::imageops::FilterType::CatmullRom, + ); + for x in 0..img.width() { + for y in 0..img.height() { + let mask_p = imageproc::drawing::Canvas::get_pixel(&mask_img, x, y); + if mask_p.0[0] > 100 { + let mut img_p = imageproc::drawing::Canvas::get_pixel(&img, x, y); + img_p.0[2] = 255 - (255 - img_p.0[2]) / 2; + img_p.0[1] /= 2; + img_p.0[0] /= 2; + imageproc::drawing::Canvas::draw_pixel(&mut img, x, y, img_p) } } - match point { - Some((x, y)) => { - let (x, y) = ( - (x * img.width() as f64) as i32, - (y * img.height() as f64) as i32, - ); - imageproc::drawing::draw_filled_circle( - &img, - (x, y), - 3, - image::Rgba([255, 0, 0, 
200]), - ) - .save("sam_merged.jpg")? - } - None => img.save("sam_merged.jpg")?, - }; } + let (x, y) = ( + (args.point_x * img.width() as f64) as i32, + (args.point_y * img.height() as f64) as i32, + ); + imageproc::drawing::draw_filled_circle(&img, (x, y), 3, image::Rgba([255, 0, 0, 200])) + .save("sam_merged.jpg")? } Ok(()) } diff --git a/candle-examples/examples/yolo-v8/assets/peoples.pp.jpg b/candle-examples/examples/yolo-v8/assets/peoples.pp.jpg Binary files differ deleted file mode 100644 index 1707dbfa..00000000 --- a/candle-examples/examples/yolo-v8/assets/peoples.pp.jpg +++ /dev/null |