diff options
author | Jani Monoses <jani.monoses@gmail.com> | 2024-07-09 14:52:20 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-07-09 13:52:20 +0200 |
commit | a226a9736baee550b01de53cb3e416d3d94e69d3 (patch) | |
tree | 180c92bd3503350f48c281642b85f801e65fdb03 /candle-examples/src/imagenet.rs | |
parent | 25960676caefcb10060fb36a8d66efa9fa731dec (diff) | |
download | candle-a226a9736baee550b01de53cb3e416d3d94e69d3.tar.gz candle-a226a9736baee550b01de53cb3e416d3d94e69d3.tar.bz2 candle-a226a9736baee550b01de53cb3e416d3d94e69d3.zip |
Add Mobilenet v4 (#2325)
* Support different resolutions in load_image()
* Added MobilenetV4 model.
* Add MobileNetv4 to README
Diffstat (limited to 'candle-examples/src/imagenet.rs')
-rw-r--r-- | candle-examples/src/imagenet.rs | 30 |
1 files changed, 13 insertions, 17 deletions
diff --git a/candle-examples/src/imagenet.rs b/candle-examples/src/imagenet.rs index 781dcd4f..6b079870 100644 --- a/candle-examples/src/imagenet.rs +++ b/candle-examples/src/imagenet.rs @@ -1,15 +1,16 @@ use candle::{Device, Result, Tensor}; -/// Loads an image from disk using the image crate, this returns a tensor with shape -/// (3, 224, 224). imagenet normalization is applied. -pub fn load_image224<P: AsRef<std::path::Path>>(p: P) -> Result<Tensor> { +/// Loads an image from disk using the image crate at the requested resolution. +// This returns a tensor with shape (3, res, res). imagenet normalization is applied. +pub fn load_image<P: AsRef<std::path::Path>>(p: P, res: u32) -> Result<Tensor> { let img = image::io::Reader::open(p)? .decode() .map_err(candle::Error::wrap)? - .resize_to_fill(224, 224, image::imageops::FilterType::Triangle); + .resize_to_fill(res, res, image::imageops::FilterType::Triangle); let img = img.to_rgb8(); let data = img.into_raw(); - let data = Tensor::from_vec(data, (224, 224, 3), &Device::Cpu)?.permute((2, 0, 1))?; + let data = Tensor::from_vec(data, (res as usize, res as usize, 3), &Device::Cpu)? + .permute((2, 0, 1))?; let mean = Tensor::new(&[0.485f32, 0.456, 0.406], &Device::Cpu)?.reshape((3, 1, 1))?; let std = Tensor::new(&[0.229f32, 0.224, 0.225], &Device::Cpu)?.reshape((3, 1, 1))?; (data.to_dtype(candle::DType::F32)? / 255.)? @@ -18,21 +19,16 @@ pub fn load_image224<P: AsRef<std::path::Path>>(p: P) -> Result<Tensor> { } /// Loads an image from disk using the image crate, this returns a tensor with shape +/// (3, 224, 224). imagenet normalization is applied. +pub fn load_image224<P: AsRef<std::path::Path>>(p: P) -> Result<Tensor> { + load_image(p, 224) +} + +/// Loads an image from disk using the image crate, this returns a tensor with shape /// (3, 518, 518). imagenet normalization is applied. /// The model dinov2 reg4 analyzes images with dimensions 3x518x518 (resulting in 37x37 transformer tokens). pub fn load_image518<P: AsRef<std::path::Path>>(p: P) -> Result<Tensor> { - let img = image::io::Reader::open(p)? - .decode() - .map_err(candle::Error::wrap)? - .resize_to_fill(518, 518, image::imageops::FilterType::Triangle); - let img = img.to_rgb8(); - let data = img.into_raw(); - let data = Tensor::from_vec(data, (518, 518, 3), &Device::Cpu)?.permute((2, 0, 1))?; - let mean = Tensor::new(&[0.485f32, 0.456, 0.406], &Device::Cpu)?.reshape((3, 1, 1))?; - let std = Tensor::new(&[0.229f32, 0.224, 0.225], &Device::Cpu)?.reshape((3, 1, 1))?; - (data.to_dtype(candle::DType::F32)? / 255.)? - .broadcast_sub(&mean)? - .broadcast_div(&std) + load_image(p, 518) } pub const CLASS_COUNT: i64 = 1000; |