summary | refs | log | tree | commit | diff
path: root/candle-examples/examples/segment-anything
diff options
context:
space:
mode:
Diffstat (limited to 'candle-examples/examples/segment-anything')
-rw-r--r-- candle-examples/examples/segment-anything/main.rs 59
1 files changed, 50 insertions, 9 deletions
diff --git a/candle-examples/examples/segment-anything/main.rs b/candle-examples/examples/segment-anything/main.rs
index 0f0c0482..c5095c0e 100644
--- a/candle-examples/examples/segment-anything/main.rs
+++ b/candle-examples/examples/segment-anything/main.rs
@@ -104,11 +104,11 @@ struct Args {
#[arg(long)]
generate_masks: bool,
- #[arg(long)]
- point_x: Option<f64>,
+ #[arg(long, default_value_t = 0.5)]
+ point_x: f64,
- #[arg(long)]
- point_y: Option<f64>,
+ #[arg(long, default_value_t = 0.5)]
+ point_y: f64,
}
pub fn main() -> anyhow::Result<()> {
@@ -135,7 +135,7 @@ pub fn main() -> anyhow::Result<()> {
let (_c, h, w) = image.dims3()?;
(image, h, w)
} else {
- let (image, h, w) = candle_examples::load_image(args.image, Some(model_sam::IMAGE_SIZE))?;
+ let (image, h, w) = candle_examples::load_image(&args.image, Some(model_sam::IMAGE_SIZE))?;
(image.to_device(&device)?, h, w)
};
println!("loaded image {image:?}");
@@ -163,7 +163,7 @@ pub fn main() -> anyhow::Result<()> {
/* crop_n_points_downscale_factor */ 1,
)?
} else {
- let point = args.point_x.zip(args.point_y);
+ let point = Some((args.point_x, args.point_y));
let (mask, iou_predictions) = sam.forward(&image, point, false)?;
println!("mask:\n{mask}");
println!("iou_predictions: {iou_predictions:?}");
@@ -174,9 +174,50 @@ pub fn main() -> anyhow::Result<()> {
let mask = mask.expand((3, h, w))?;
candle_examples::save_image_resize(&mask, "sam_mask.png", initial_h, initial_w)?;
- let image = sam.preprocess(&image)?;
- let image = sam.unpreprocess(&image)?.to_dtype(DType::U8)?;
- candle_examples::save_image(&image, "sam_input_scaled.png")?;
+ if !args.image.ends_with(".safetensors") {
+ let mut img = image::io::Reader::open(&args.image)?
+ .decode()
+ .map_err(candle::Error::wrap)?;
+ let mask_pixels = mask.permute((1, 2, 0))?.flatten_all()?.to_vec1::<u8>()?;
+ let mask_img: image::ImageBuffer<image::Rgb<u8>, Vec<u8>> =
+ match image::ImageBuffer::from_raw(w as u32, h as u32, mask_pixels) {
+ Some(image) => image,
+ None => anyhow::bail!("error saving merged image"),
+ };
+ let mask_img = image::DynamicImage::from(mask_img).resize_to_fill(
+ img.width(),
+ img.height(),
+ image::imageops::FilterType::CatmullRom,
+ );
+ for x in 0..img.width() {
+ for y in 0..img.height() {
+ let mask_p = imageproc::drawing::Canvas::get_pixel(&mask_img, x, y);
+ if mask_p.0[0] > 100 {
+ let mut img_p = imageproc::drawing::Canvas::get_pixel(&img, x, y);
+ img_p.0[2] = 255 - (255 - img_p.0[2]) / 2;
+ img_p.0[1] /= 2;
+ img_p.0[0] /= 2;
+ imageproc::drawing::Canvas::draw_pixel(&mut img, x, y, img_p)
+ }
+ }
+ }
+ match point {
+ Some((x, y)) => {
+ let (x, y) = (
+ (x * img.width() as f64) as i32,
+ (y * img.height() as f64) as i32,
+ );
+ imageproc::drawing::draw_filled_circle(
+ &img,
+ (x, y),
+ 3,
+ image::Rgba([255, 0, 0, 200]),
+ )
+ .save("sam_merged.jpg")?
+ }
+ None => img.save("sam_merged.jpg")?,
+ };
+ }
}
Ok(())
}