diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-09-07 21:45:16 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-09-07 21:45:16 +0100 |
commit | 79c27fc489f2eece486fa433a0ae75c66a398e6f (patch) | |
tree | c86b89b2c1270cf3207c5242de27c6b069652044 | |
parent | 7396b8ed1a5394c58fcc772e5f6e6038577505b8 (diff) | |
download | candle-79c27fc489f2eece486fa433a0ae75c66a398e6f.tar.gz candle-79c27fc489f2eece486fa433a0ae75c66a398e6f.tar.bz2 candle-79c27fc489f2eece486fa433a0ae75c66a398e6f.zip |
Segment-anything fixes: avoid normalizing twice. (#767)
* Segment-anything fixes: avoid normalizing twice.
* More fixes for the image aspect ratio.
-rw-r--r-- | candle-examples/examples/segment-anything/main.rs | 5 | ||||
-rw-r--r-- | candle-examples/examples/segment-anything/model_sam.rs | 3 | ||||
-rw-r--r-- | candle-examples/src/lib.rs | 28 |
3 files changed, 33 insertions, 3 deletions
```diff
diff --git a/candle-examples/examples/segment-anything/main.rs b/candle-examples/examples/segment-anything/main.rs
index a2722270..03ebe346 100644
--- a/candle-examples/examples/segment-anything/main.rs
+++ b/candle-examples/examples/segment-anything/main.rs
@@ -108,7 +108,8 @@ pub fn main() -> anyhow::Result<()> {
 
     let device = candle_examples::device(args.cpu)?;
 
-    let image = candle_examples::imagenet::load_image224(args.image)?.to_device(&device)?;
+    let image =
+        candle_examples::load_image(args.image, Some(model_sam::IMAGE_SIZE))?.to_device(&device)?;
     println!("loaded image {image:?}");
 
     let model = match args.model {
@@ -125,7 +126,7 @@ pub fn main() -> anyhow::Result<()> {
     let sam = model_sam::Sam::new(768, 12, 12, &[2, 5, 8, 11], vb)?; // sam_vit_b
 
     let (mask, iou_predictions) = sam.forward(&image, false)?;
-    println!("mask: {mask:?}");
+    println!("mask:\n{mask}");
     println!("iou_predictions: {iou_predictions:?}");
     Ok(())
 }
diff --git a/candle-examples/examples/segment-anything/model_sam.rs b/candle-examples/examples/segment-anything/model_sam.rs
index 1c8e9a59..acba7ef4 100644
--- a/candle-examples/examples/segment-anything/model_sam.rs
+++ b/candle-examples/examples/segment-anything/model_sam.rs
@@ -6,7 +6,7 @@ use crate::model_mask_decoder::MaskDecoder;
 use crate::model_prompt_encoder::PromptEncoder;
 
 const PROMPT_EMBED_DIM: usize = 256;
-const IMAGE_SIZE: usize = 1024;
+pub const IMAGE_SIZE: usize = 1024;
 const VIT_PATCH_SIZE: usize = 16;
 
 #[derive(Debug)]
@@ -90,6 +90,7 @@ impl Sam {
     fn preprocess(&self, img: &Tensor) -> Result<Tensor> {
         let (c, h, w) = img.dims3()?;
         let img = img
+            .to_dtype(DType::F32)?
             .broadcast_sub(&self.pixel_mean)?
             .broadcast_div(&self.pixel_std)?;
         if h > IMAGE_SIZE || w > IMAGE_SIZE {
diff --git a/candle-examples/src/lib.rs b/candle-examples/src/lib.rs
index f9581b02..66cd2f99 100644
--- a/candle-examples/src/lib.rs
+++ b/candle-examples/src/lib.rs
@@ -16,6 +16,34 @@ pub fn device(cpu: bool) -> Result<Device> {
     }
 }
 
+pub fn load_image<P: AsRef<std::path::Path>>(
+    p: P,
+    resize_longest: Option<usize>,
+) -> Result<Tensor> {
+    let img = image::io::Reader::open(p)?
+        .decode()
+        .map_err(candle::Error::wrap)?;
+    let img = match resize_longest {
+        None => img,
+        Some(resize_longest) => {
+            let (height, width) = (img.height(), img.width());
+            let resize_longest = resize_longest as u32;
+            let (height, width) = if height < width {
+                let h = (resize_longest * height) / width;
+                (h, resize_longest)
+            } else {
+                let w = (resize_longest * width) / height;
+                (resize_longest, w)
+            };
+            img.resize_exact(width, height, image::imageops::FilterType::CatmullRom)
+        }
+    };
+    let (height, width) = (img.height() as usize, img.width() as usize);
+    let img = img.to_rgb8();
+    let data = img.into_raw();
+    Tensor::from_vec(data, (height, width, 3), &Device::Cpu)?.permute((2, 0, 1))
+}
+
 pub fn load_image_and_resize<P: AsRef<std::path::Path>>(
     p: P,
     width: usize,
```