summaryrefslogtreecommitdiff
path: root/candle-examples/src/imagenet.rs
diff options
context:
space:
mode:
authorJani Monoses <jani.monoses@gmail.com>2024-07-09 14:52:20 +0300
committerGitHub <noreply@github.com>2024-07-09 13:52:20 +0200
commita226a9736baee550b01de53cb3e416d3d94e69d3 (patch)
tree180c92bd3503350f48c281642b85f801e65fdb03 /candle-examples/src/imagenet.rs
parent25960676caefcb10060fb36a8d66efa9fa731dec (diff)
downloadcandle-a226a9736baee550b01de53cb3e416d3d94e69d3.tar.gz
candle-a226a9736baee550b01de53cb3e416d3d94e69d3.tar.bz2
candle-a226a9736baee550b01de53cb3e416d3d94e69d3.zip
Add Mobilenet v4 (#2325)
* Support different resolutions in load_image() * Added MobilenetV4 model. * Add MobileNetv4 to README
Diffstat (limited to 'candle-examples/src/imagenet.rs')
-rw-r--r--candle-examples/src/imagenet.rs30
1 files changed, 13 insertions, 17 deletions
diff --git a/candle-examples/src/imagenet.rs b/candle-examples/src/imagenet.rs
index 781dcd4f..6b079870 100644
--- a/candle-examples/src/imagenet.rs
+++ b/candle-examples/src/imagenet.rs
@@ -1,15 +1,16 @@
use candle::{Device, Result, Tensor};
-/// Loads an image from disk using the image crate, this returns a tensor with shape
-/// (3, 224, 224). imagenet normalization is applied.
-pub fn load_image224<P: AsRef<std::path::Path>>(p: P) -> Result<Tensor> {
+/// Loads an image from disk using the image crate at the requested resolution.
+// This returns a tensor with shape (3, res, res). imagenet normalization is applied.
+pub fn load_image<P: AsRef<std::path::Path>>(p: P, res: u32) -> Result<Tensor> {
let img = image::io::Reader::open(p)?
.decode()
.map_err(candle::Error::wrap)?
- .resize_to_fill(224, 224, image::imageops::FilterType::Triangle);
+ .resize_to_fill(res, res, image::imageops::FilterType::Triangle);
let img = img.to_rgb8();
let data = img.into_raw();
- let data = Tensor::from_vec(data, (224, 224, 3), &Device::Cpu)?.permute((2, 0, 1))?;
+ let data = Tensor::from_vec(data, (res as usize, res as usize, 3), &Device::Cpu)?
+ .permute((2, 0, 1))?;
let mean = Tensor::new(&[0.485f32, 0.456, 0.406], &Device::Cpu)?.reshape((3, 1, 1))?;
let std = Tensor::new(&[0.229f32, 0.224, 0.225], &Device::Cpu)?.reshape((3, 1, 1))?;
(data.to_dtype(candle::DType::F32)? / 255.)?
@@ -18,21 +19,16 @@ pub fn load_image224<P: AsRef<std::path::Path>>(p: P) -> Result<Tensor> {
}
/// Loads an image from disk using the image crate, this returns a tensor with shape
+/// (3, 224, 224). imagenet normalization is applied.
+pub fn load_image224<P: AsRef<std::path::Path>>(p: P) -> Result<Tensor> {
+ load_image(p, 224)
+}
+
+/// Loads an image from disk using the image crate, this returns a tensor with shape
/// (3, 518, 518). imagenet normalization is applied.
/// The model dinov2 reg4 analyzes images with dimensions 3x518x518 (resulting in 37x37 transformer tokens).
pub fn load_image518<P: AsRef<std::path::Path>>(p: P) -> Result<Tensor> {
- let img = image::io::Reader::open(p)?
- .decode()
- .map_err(candle::Error::wrap)?
- .resize_to_fill(518, 518, image::imageops::FilterType::Triangle);
- let img = img.to_rgb8();
- let data = img.into_raw();
- let data = Tensor::from_vec(data, (518, 518, 3), &Device::Cpu)?.permute((2, 0, 1))?;
- let mean = Tensor::new(&[0.485f32, 0.456, 0.406], &Device::Cpu)?.reshape((3, 1, 1))?;
- let std = Tensor::new(&[0.229f32, 0.224, 0.225], &Device::Cpu)?.reshape((3, 1, 1))?;
- (data.to_dtype(candle::DType::F32)? / 255.)?
- .broadcast_sub(&mean)?
- .broadcast_div(&std)
+ load_image(p, 518)
}
pub const CLASS_COUNT: i64 = 1000;