summaryrefslogtreecommitdiff
path: root/candle-examples/src/imagenet.rs
diff options
context:
space:
mode:
authorJani Monoses <jani.monoses@gmail.com>2024-08-29 16:38:58 +0300
committerGitHub <noreply@github.com>2024-08-29 15:38:58 +0200
commit86613c00e216750f32a326dbff5cc993d5e0067e (patch)
tree3b51fc0d3a1a930b57a86d6f1fd94041f68187cf /candle-examples/src/imagenet.rs
parent29e25c458da869808502a88024f57b1c86efa090 (diff)
downloadcandle-86613c00e216750f32a326dbff5cc993d5e0067e.tar.gz
candle-86613c00e216750f32a326dbff5cc993d5e0067e.tar.bz2
candle-86613c00e216750f32a326dbff5cc993d5e0067e.zip
MobileCLIP models S1 and S2 (#2454)
* Allow loading images with given std and mean * OpenCLIP text encoder component * Two MobileCLIP models * Clippy fixes. --------- Co-authored-by: Laurent <laurent.mazare@gmail.com>
Diffstat (limited to 'candle-examples/src/imagenet.rs')
-rw-r--r--candle-examples/src/imagenet.rs35
1 files changed, 27 insertions, 8 deletions
diff --git a/candle-examples/src/imagenet.rs b/candle-examples/src/imagenet.rs
index 6fcda424..a3b12423 100644
--- a/candle-examples/src/imagenet.rs
+++ b/candle-examples/src/imagenet.rs
@@ -1,23 +1,42 @@
use candle::{Device, Result, Tensor};
-/// Loads an image from disk using the image crate at the requested resolution.
-// This returns a tensor with shape (3, res, res). imagenet normalization is applied.
-pub fn load_image<P: AsRef<std::path::Path>>(p: P, res: u32) -> Result<Tensor> {
+pub const IMAGENET_MEAN: [f32; 3] = [0.485f32, 0.456, 0.406];
+pub const IMAGENET_STD: [f32; 3] = [0.229f32, 0.224, 0.225];
+
+/// Loads an image from disk using the image crate at the requested resolution,
+/// using the given std and mean parameters.
+/// This returns a tensor with shape (3, res, res). imagenet normalization is applied.
+
+pub fn load_image_with_std_mean<P: AsRef<std::path::Path>>(
+ p: P,
+ res: usize,
+ mean: &[f32; 3],
+ std: &[f32; 3],
+) -> Result<Tensor> {
let img = image::ImageReader::open(p)?
.decode()
.map_err(candle::Error::wrap)?
- .resize_to_fill(res, res, image::imageops::FilterType::Triangle);
+ .resize_to_fill(
+ res as u32,
+ res as u32,
+ image::imageops::FilterType::Triangle,
+ );
let img = img.to_rgb8();
let data = img.into_raw();
- let data = Tensor::from_vec(data, (res as usize, res as usize, 3), &Device::Cpu)?
- .permute((2, 0, 1))?;
- let mean = Tensor::new(&[0.485f32, 0.456, 0.406], &Device::Cpu)?.reshape((3, 1, 1))?;
- let std = Tensor::new(&[0.229f32, 0.224, 0.225], &Device::Cpu)?.reshape((3, 1, 1))?;
+ let data = Tensor::from_vec(data, (res, res, 3), &Device::Cpu)?.permute((2, 0, 1))?;
+ let mean = Tensor::new(mean, &Device::Cpu)?.reshape((3, 1, 1))?;
+ let std = Tensor::new(std, &Device::Cpu)?.reshape((3, 1, 1))?;
(data.to_dtype(candle::DType::F32)? / 255.)?
.broadcast_sub(&mean)?
.broadcast_div(&std)
}
+/// Loads an image from disk using the image crate at the requested resolution.
+/// This returns a tensor with shape (3, res, res). imagenet normalization is applied.
+pub fn load_image<P: AsRef<std::path::Path>>(p: P, res: usize) -> Result<Tensor> {
+ load_image_with_std_mean(p, res, &IMAGENET_MEAN, &IMAGENET_STD)
+}
+
/// Loads an image from disk using the image crate, this returns a tensor with shape
/// (3, 224, 224). imagenet normalization is applied.
pub fn load_image224<P: AsRef<std::path::Path>>(p: P) -> Result<Tensor> {