summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: Laurent Mazare <laurent.mazare@gmail.com>, 2023-09-07 21:45:16 +0100
committer: GitHub <noreply@github.com>, 2023-09-07 21:45:16 +0100
commit: 79c27fc489f2eece486fa433a0ae75c66a398e6f (patch)
tree: c86b89b2c1270cf3207c5242de27c6b069652044
parent: 7396b8ed1a5394c58fcc772e5f6e6038577505b8 (diff)
download: candle-79c27fc489f2eece486fa433a0ae75c66a398e6f.tar.gz
candle-79c27fc489f2eece486fa433a0ae75c66a398e6f.tar.bz2
candle-79c27fc489f2eece486fa433a0ae75c66a398e6f.zip
Segment-anything fixes: avoid normalizing twice. (#767)
* Segment-anything fixes: avoid normalizing twice.
* More fixes for the image aspect ratio.
-rw-r--r-- candle-examples/examples/segment-anything/main.rs | 5
-rw-r--r-- candle-examples/examples/segment-anything/model_sam.rs | 3
-rw-r--r-- candle-examples/src/lib.rs | 28
3 files changed, 33 insertions, 3 deletions
diff --git a/candle-examples/examples/segment-anything/main.rs b/candle-examples/examples/segment-anything/main.rs
index a2722270..03ebe346 100644
--- a/candle-examples/examples/segment-anything/main.rs
+++ b/candle-examples/examples/segment-anything/main.rs
@@ -108,7 +108,8 @@ pub fn main() -> anyhow::Result<()> {
let device = candle_examples::device(args.cpu)?;
- let image = candle_examples::imagenet::load_image224(args.image)?.to_device(&device)?;
+ let image =
+ candle_examples::load_image(args.image, Some(model_sam::IMAGE_SIZE))?.to_device(&device)?;
println!("loaded image {image:?}");
let model = match args.model {
@@ -125,7 +126,7 @@ pub fn main() -> anyhow::Result<()> {
let sam = model_sam::Sam::new(768, 12, 12, &[2, 5, 8, 11], vb)?; // sam_vit_b
let (mask, iou_predictions) = sam.forward(&image, false)?;
- println!("mask: {mask:?}");
+ println!("mask:\n{mask}");
println!("iou_predictions: {iou_predictions:?}");
Ok(())
}
diff --git a/candle-examples/examples/segment-anything/model_sam.rs b/candle-examples/examples/segment-anything/model_sam.rs
index 1c8e9a59..acba7ef4 100644
--- a/candle-examples/examples/segment-anything/model_sam.rs
+++ b/candle-examples/examples/segment-anything/model_sam.rs
@@ -6,7 +6,7 @@ use crate::model_mask_decoder::MaskDecoder;
use crate::model_prompt_encoder::PromptEncoder;
const PROMPT_EMBED_DIM: usize = 256;
-const IMAGE_SIZE: usize = 1024;
+pub const IMAGE_SIZE: usize = 1024;
const VIT_PATCH_SIZE: usize = 16;
#[derive(Debug)]
@@ -90,6 +90,7 @@ impl Sam {
fn preprocess(&self, img: &Tensor) -> Result<Tensor> {
let (c, h, w) = img.dims3()?;
let img = img
+ .to_dtype(DType::F32)?
.broadcast_sub(&self.pixel_mean)?
.broadcast_div(&self.pixel_std)?;
if h > IMAGE_SIZE || w > IMAGE_SIZE {
diff --git a/candle-examples/src/lib.rs b/candle-examples/src/lib.rs
index f9581b02..66cd2f99 100644
--- a/candle-examples/src/lib.rs
+++ b/candle-examples/src/lib.rs
@@ -16,6 +16,34 @@ pub fn device(cpu: bool) -> Result<Device> {
}
}
+pub fn load_image<P: AsRef<std::path::Path>>(
+ p: P,
+ resize_longest: Option<usize>,
+) -> Result<Tensor> {
+ let img = image::io::Reader::open(p)?
+ .decode()
+ .map_err(candle::Error::wrap)?;
+ let img = match resize_longest {
+ None => img,
+ Some(resize_longest) => {
+ let (height, width) = (img.height(), img.width());
+ let resize_longest = resize_longest as u32;
+ let (height, width) = if height < width {
+ let h = (resize_longest * height) / width;
+ (h, resize_longest)
+ } else {
+ let w = (resize_longest * width) / height;
+ (resize_longest, w)
+ };
+ img.resize_exact(width, height, image::imageops::FilterType::CatmullRom)
+ }
+ };
+ let (height, width) = (img.height() as usize, img.width() as usize);
+ let img = img.to_rgb8();
+ let data = img.into_raw();
+ Tensor::from_vec(data, (height, width, 3), &Device::Cpu)?.permute((2, 0, 1))
+}
+
pub fn load_image_and_resize<P: AsRef<std::path::Path>>(
p: P,
width: usize,