summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author: Laurent Mazare <laurent.mazare@gmail.com> 2023-09-12 14:35:55 +0100
committer: GitHub <noreply@github.com> 2023-09-12 14:35:55 +0100
commit: 42da17694a4214a3e39e0d64afc22635ce83f557 (patch)
tree: d996a2dd9cbe3eb4aef17dc26fccddc957326553
parent: 25aacda28eff79d14be09d636275b0351d81069d (diff)
download: candle-42da17694a4214a3e39e0d64afc22635ce83f557.tar.gz
candle-42da17694a4214a3e39e0d64afc22635ce83f557.tar.bz2
candle-42da17694a4214a3e39e0d64afc22635ce83f557.zip
Segment Anything readme (#827)
* Add a readme for the segment-anything model. * Add the original image. * Clean-up the segment anything cli example. * Also print the mask id in the outputs.
-rw-r--r-- candle-examples/examples/segment-anything/README.md 40
-rw-r--r-- candle-examples/examples/segment-anything/assets/sam_merged.jpg bin 0 -> 160984 bytes
-rw-r--r-- candle-examples/examples/segment-anything/main.rs 110
-rw-r--r-- candle-examples/examples/yolo-v8/assets/peoples.pp.jpg bin 81845 -> 0 bytes
4 files changed, 82 insertions, 68 deletions
diff --git a/candle-examples/examples/segment-anything/README.md b/candle-examples/examples/segment-anything/README.md
new file mode 100644
index 00000000..3c5b034f
--- /dev/null
+++ b/candle-examples/examples/segment-anything/README.md
@@ -0,0 +1,40 @@
+# candle-segment-anything: Segment-Anything Model
+
+This example is based on Meta AI's [Segment-Anything
+Model](https://github.com/facebookresearch/segment-anything). This model
+provides a robust and fast image segmentation pipeline that can be tweaked via
+some prompting (requesting some points to be in the target mask, requesting some
+points to be part of the background so _not_ in the target mask, specifying some
+bounding box).
+
+The default backbone can be replaced by the smaller and faster TinyViT model
+based on [MobileSAM](https://github.com/ChaoningZhang/MobileSAM).
+
+## Running an example
+
+```bash
+cargo run --example segment-anything --release -- \
+ --image candle-examples/examples/yolo-v8/assets/bike.jpg \
+ --use-tiny \
+ --point-x 0.4 \
+ --point-y 0.3
+```
+
+Running this command generates a `sam_merged.jpg` file containing the original
+image with a blue overlay of the selected mask. The red dot represents the prompt
+specified by `--point-x 0.4 --point-y 0.3`; this prompt is assumed to be part
+of the target mask.
+
+The values used for `--point-x` and `--point-y` should be between 0 and 1 and
+are proportional to the image dimensions, i.e. use 0.5 for the image center.
+
+![Leading group, Giro d'Italia 2021](../yolo-v8/assets/bike.jpg)
+
+![Leading group, Giro d'Italia 2021](./assets/sam_merged.jpg)
+
+### Command-line flags
+- `--use-tiny`: use the TinyViT based MobileSAM backbone rather than the default
+ one.
+- `--point-x`, `--point-y`: specify the location of the target point.
+- `--threshold`: sets the threshold at or above which a value is considered part of the
+ mask; a negative value results in a larger mask and can be specified via `--threshold=-1.2`.
diff --git a/candle-examples/examples/segment-anything/assets/sam_merged.jpg b/candle-examples/examples/segment-anything/assets/sam_merged.jpg
new file mode 100644
index 00000000..a5f64e5e
--- /dev/null
+++ b/candle-examples/examples/segment-anything/assets/sam_merged.jpg
Binary files differ
diff --git a/candle-examples/examples/segment-anything/main.rs b/candle-examples/examples/segment-anything/main.rs
index 21ba0415..3d9898b6 100644
--- a/candle-examples/examples/segment-anything/main.rs
+++ b/candle-examples/examples/segment-anything/main.rs
@@ -27,12 +27,19 @@ struct Args {
#[arg(long)]
generate_masks: bool,
+ /// The target point x coordinate, between 0 and 1 (0.5 is at the middle of the image).
#[arg(long, default_value_t = 0.5)]
point_x: f64,
+ /// The target point y coordinate, between 0 and 1 (0.5 is at the middle of the image).
#[arg(long, default_value_t = 0.5)]
point_y: f64,
+ /// The detection threshold for the mask, 0 is the default value, negative values mean a larger
+ /// mask, positive makes the mask more selective.
+ #[arg(long, default_value_t = 0.)]
+ threshold: f32,
+
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
@@ -57,28 +64,9 @@ pub fn main() -> anyhow::Result<()> {
let device = candle_examples::device(args.cpu)?;
- let (image, initial_h, initial_w) = if args.image.ends_with(".safetensors") {
- let mut tensors = candle::safetensors::load(&args.image, &device)?;
- let image = match tensors.remove("image") {
- Some(image) => image,
- None => {
- if tensors.len() != 1 {
- anyhow::bail!("multiple tensors in '{}'", args.image)
- }
- tensors.into_values().next().unwrap()
- }
- };
- let image = if image.rank() == 4 {
- image.get(0)?
- } else {
- image
- };
- let (_c, h, w) = image.dims3()?;
- (image, h, w)
- } else {
- let (image, h, w) = candle_examples::load_image(&args.image, Some(sam::IMAGE_SIZE))?;
- (image.to_device(&device)?, h, w)
- };
+ let (image, initial_h, initial_w) =
+ candle_examples::load_image(&args.image, Some(sam::IMAGE_SIZE))?;
+ let image = image.to_device(&device)?;
println!("loaded image {image:?}");
let model = match args.model {
@@ -113,7 +101,7 @@ pub fn main() -> anyhow::Result<()> {
/* crop_n_points_downscale_factor */ 1,
)?;
for (idx, bbox) in bboxes.iter().enumerate() {
- println!("{bbox:?}");
+ println!("{idx} {bbox:?}");
let mask = (&bbox.data.to_dtype(DType::U8)? * 255.)?;
let (h, w) = mask.dims2()?;
let mask = mask.broadcast_as((3, h, w))?;
@@ -135,56 +123,42 @@ pub fn main() -> anyhow::Result<()> {
println!("mask:\n{mask}");
println!("iou_predictions: {iou_predictions:?}");
- // Save the mask as an image.
- let mask = (mask.ge(0f32)? * 255.)?;
+ let mask = (mask.ge(args.threshold)? * 255.)?;
let (_one, h, w) = mask.dims3()?;
let mask = mask.expand((3, h, w))?;
- candle_examples::save_image_resize(&mask, "sam_mask.png", initial_h, initial_w)?;
-
- if !args.image.ends_with(".safetensors") {
- let mut img = image::io::Reader::open(&args.image)?
- .decode()
- .map_err(candle::Error::wrap)?;
- let mask_pixels = mask.permute((1, 2, 0))?.flatten_all()?.to_vec1::<u8>()?;
- let mask_img: image::ImageBuffer<image::Rgb<u8>, Vec<u8>> =
- match image::ImageBuffer::from_raw(w as u32, h as u32, mask_pixels) {
- Some(image) => image,
- None => anyhow::bail!("error saving merged image"),
- };
- let mask_img = image::DynamicImage::from(mask_img).resize_to_fill(
- img.width(),
- img.height(),
- image::imageops::FilterType::CatmullRom,
- );
- for x in 0..img.width() {
- for y in 0..img.height() {
- let mask_p = imageproc::drawing::Canvas::get_pixel(&mask_img, x, y);
- if mask_p.0[0] > 100 {
- let mut img_p = imageproc::drawing::Canvas::get_pixel(&img, x, y);
- img_p.0[2] = 255 - (255 - img_p.0[2]) / 2;
- img_p.0[1] /= 2;
- img_p.0[0] /= 2;
- imageproc::drawing::Canvas::draw_pixel(&mut img, x, y, img_p)
- }
+
+ let mut img = image::io::Reader::open(&args.image)?
+ .decode()
+ .map_err(candle::Error::wrap)?;
+ let mask_pixels = mask.permute((1, 2, 0))?.flatten_all()?.to_vec1::<u8>()?;
+ let mask_img: image::ImageBuffer<image::Rgb<u8>, Vec<u8>> =
+ match image::ImageBuffer::from_raw(w as u32, h as u32, mask_pixels) {
+ Some(image) => image,
+ None => anyhow::bail!("error saving merged image"),
+ };
+ let mask_img = image::DynamicImage::from(mask_img).resize_to_fill(
+ img.width(),
+ img.height(),
+ image::imageops::FilterType::CatmullRom,
+ );
+ for x in 0..img.width() {
+ for y in 0..img.height() {
+ let mask_p = imageproc::drawing::Canvas::get_pixel(&mask_img, x, y);
+ if mask_p.0[0] > 100 {
+ let mut img_p = imageproc::drawing::Canvas::get_pixel(&img, x, y);
+ img_p.0[2] = 255 - (255 - img_p.0[2]) / 2;
+ img_p.0[1] /= 2;
+ img_p.0[0] /= 2;
+ imageproc::drawing::Canvas::draw_pixel(&mut img, x, y, img_p)
}
}
- match point {
- Some((x, y)) => {
- let (x, y) = (
- (x * img.width() as f64) as i32,
- (y * img.height() as f64) as i32,
- );
- imageproc::drawing::draw_filled_circle(
- &img,
- (x, y),
- 3,
- image::Rgba([255, 0, 0, 200]),
- )
- .save("sam_merged.jpg")?
- }
- None => img.save("sam_merged.jpg")?,
- };
}
+ let (x, y) = (
+ (args.point_x * img.width() as f64) as i32,
+ (args.point_y * img.height() as f64) as i32,
+ );
+ imageproc::drawing::draw_filled_circle(&img, (x, y), 3, image::Rgba([255, 0, 0, 200]))
+ .save("sam_merged.jpg")?
}
Ok(())
}
diff --git a/candle-examples/examples/yolo-v8/assets/peoples.pp.jpg b/candle-examples/examples/yolo-v8/assets/peoples.pp.jpg
deleted file mode 100644
index 1707dbfa..00000000
--- a/candle-examples/examples/yolo-v8/assets/peoples.pp.jpg
+++ /dev/null
Binary files differ