From 79c27fc489f2eece486fa433a0ae75c66a398e6f Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Thu, 7 Sep 2023 21:45:16 +0100 Subject: Segment-anything fixes: avoid normalizing twice. (#767) * Segment-anything fixes: avoid normalizing twice. * More fixes for the image aspect ratio. --- candle-examples/src/lib.rs | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) (limited to 'candle-examples/src') diff --git a/candle-examples/src/lib.rs b/candle-examples/src/lib.rs index f9581b02..66cd2f99 100644 --- a/candle-examples/src/lib.rs +++ b/candle-examples/src/lib.rs @@ -16,6 +16,34 @@ pub fn device(cpu: bool) -> Result { } } +pub fn load_image>( + p: P, + resize_longest: Option, +) -> Result { + let img = image::io::Reader::open(p)? + .decode() + .map_err(candle::Error::wrap)?; + let img = match resize_longest { + None => img, + Some(resize_longest) => { + let (height, width) = (img.height(), img.width()); + let resize_longest = resize_longest as u32; + let (height, width) = if height < width { + let h = (resize_longest * height) / width; + (h, resize_longest) + } else { + let w = (resize_longest * width) / height; + (resize_longest, w) + }; + img.resize_exact(width, height, image::imageops::FilterType::CatmullRom) + } + }; + let (height, width) = (img.height() as usize, img.width() as usize); + let img = img.to_rgb8(); + let data = img.into_raw(); + Tensor::from_vec(data, (height, width, 3), &Device::Cpu)?.permute((2, 0, 1)) +} + pub fn load_image_and_resize>( p: P, width: usize, -- cgit v1.2.3