diff options
author | Jani Monoses <jani.monoses@gmail.com> | 2024-08-29 16:38:58 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-08-29 15:38:58 +0200 |
commit | 86613c00e216750f32a326dbff5cc993d5e0067e (patch) | |
tree | 3b51fc0d3a1a930b57a86d6f1fd94041f68187cf /candle-examples/src/imagenet.rs | |
parent | 29e25c458da869808502a88024f57b1c86efa090 (diff) | |
download | candle-86613c00e216750f32a326dbff5cc993d5e0067e.tar.gz candle-86613c00e216750f32a326dbff5cc993d5e0067e.tar.bz2 candle-86613c00e216750f32a326dbff5cc993d5e0067e.zip |
MobileCLIP models S1 and S2 (#2454)
* Allow loading images with given std and mean
* OpenCLIP text encoder component
* Two MobileCLIP models
* Clippy fixes.
---------
Co-authored-by: Laurent <laurent.mazare@gmail.com>
Diffstat (limited to 'candle-examples/src/imagenet.rs')
-rw-r--r-- | candle-examples/src/imagenet.rs | 35 |
1 files changed, 27 insertions, 8 deletions
diff --git a/candle-examples/src/imagenet.rs b/candle-examples/src/imagenet.rs index 6fcda424..a3b12423 100644 --- a/candle-examples/src/imagenet.rs +++ b/candle-examples/src/imagenet.rs @@ -1,23 +1,42 @@ use candle::{Device, Result, Tensor}; -/// Loads an image from disk using the image crate at the requested resolution. -// This returns a tensor with shape (3, res, res). imagenet normalization is applied. -pub fn load_image<P: AsRef<std::path::Path>>(p: P, res: u32) -> Result<Tensor> { +pub const IMAGENET_MEAN: [f32; 3] = [0.485f32, 0.456, 0.406]; +pub const IMAGENET_STD: [f32; 3] = [0.229f32, 0.224, 0.225]; + +/// Loads an image from disk using the image crate at the requested resolution, +/// using the given std and mean parameters. +/// This returns a tensor with shape (3, res, res). imagenet normalization is applied. + +pub fn load_image_with_std_mean<P: AsRef<std::path::Path>>( + p: P, + res: usize, + mean: &[f32; 3], + std: &[f32; 3], +) -> Result<Tensor> { let img = image::ImageReader::open(p)? .decode() .map_err(candle::Error::wrap)? - .resize_to_fill(res, res, image::imageops::FilterType::Triangle); + .resize_to_fill( + res as u32, + res as u32, + image::imageops::FilterType::Triangle, + ); let img = img.to_rgb8(); let data = img.into_raw(); - let data = Tensor::from_vec(data, (res as usize, res as usize, 3), &Device::Cpu)? - .permute((2, 0, 1))?; - let mean = Tensor::new(&[0.485f32, 0.456, 0.406], &Device::Cpu)?.reshape((3, 1, 1))?; - let std = Tensor::new(&[0.229f32, 0.224, 0.225], &Device::Cpu)?.reshape((3, 1, 1))?; + let data = Tensor::from_vec(data, (res, res, 3), &Device::Cpu)?.permute((2, 0, 1))?; + let mean = Tensor::new(mean, &Device::Cpu)?.reshape((3, 1, 1))?; + let std = Tensor::new(std, &Device::Cpu)?.reshape((3, 1, 1))?; (data.to_dtype(candle::DType::F32)? / 255.)? .broadcast_sub(&mean)? .broadcast_div(&std) } +/// Loads an image from disk using the image crate at the requested resolution. +/// This returns a tensor with shape (3, res, res). imagenet normalization is applied. +pub fn load_image<P: AsRef<std::path::Path>>(p: P, res: usize) -> Result<Tensor> { + load_image_with_std_mean(p, res, &IMAGENET_MEAN, &IMAGENET_STD) +} + /// Loads an image from disk using the image crate, this returns a tensor with shape /// (3, 224, 224). imagenet normalization is applied. pub fn load_image224<P: AsRef<std::path::Path>>(p: P) -> Result<Tensor> { |