diff options
author | v-espitalier <125037408+v-espitalier@users.noreply.github.com> | 2024-07-07 20:09:31 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-07-07 20:09:31 +0200 |
commit | 9cd54aa5d4fb6cf30e0df2d198c8a387db2d9144 (patch) | |
tree | 9988e786128416cf8c77425658e29a67c904a5ad /candle-examples/examples/eva2 | |
parent | eec11ce2ce1e81f0fdb1cac5405d07286242dc01 (diff) | |
download | candle-9cd54aa5d4fb6cf30e0df2d198c8a387db2d9144.tar.gz candle-9cd54aa5d4fb6cf30e0df2d198c8a387db2d9144.tar.bz2 candle-9cd54aa5d4fb6cf30e0df2d198c8a387db2d9144.zip |
Add EVA-02 model ( https://arxiv.org/abs/2303.11331 ) (#2311)
* Add EVA-02 model ( https://arxiv.org/abs/2303.11331 )
* Clippy fix.
* And apply fmt.
---------
Co-authored-by: v-espitalier <>
Co-authored-by: Laurent <laurent.mazare@gmail.com>
Diffstat (limited to 'candle-examples/examples/eva2')
-rw-r--r-- | candle-examples/examples/eva2/README.md | 21 | ||||
-rw-r--r-- | candle-examples/examples/eva2/main.rs | 82 |
2 files changed, 103 insertions, 0 deletions
diff --git a/candle-examples/examples/eva2/README.md b/candle-examples/examples/eva2/README.md new file mode 100644 index 00000000..10c91b89 --- /dev/null +++ b/candle-examples/examples/eva2/README.md @@ -0,0 +1,21 @@ +# candle-eva2 + +[EVA-02](https://arxiv.org/abs/2303.11331) is a computer vision model. +In this example, it is used as an ImageNet classifier: the model returns the +probability for the image to belong to each of the 1000 ImageNet categories. + +## Running some example + +```bash +cargo run --example eva2 --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg + +> mountain bike, all-terrain bike, off-roader: 37.09% +> maillot : 8.30% +> alp : 2.13% +> bicycle-built-for-two, tandem bicycle, tandem: 0.84% +> crash helmet : 0.73% + + +``` + + diff --git a/candle-examples/examples/eva2/main.rs b/candle-examples/examples/eva2/main.rs new file mode 100644 index 00000000..4270075d --- /dev/null +++ b/candle-examples/examples/eva2/main.rs @@ -0,0 +1,82 @@ +//! EVA-02: Explore the limits of Visual representation at scAle +//! https://github.com/baaivision/EVA + +#[cfg(feature = "mkl")] +extern crate intel_mkl_src; + +#[cfg(feature = "accelerate")] +extern crate accelerate_src; + +use clap::Parser; + +use candle::{DType, Device, IndexOp, Result, Tensor, D}; +use candle_nn::{Module, VarBuilder}; +use candle_transformers::models::eva2; + +/// Loads an image from disk using the image crate, this returns a tensor with shape +/// (3, 448, 448). OpenAI normalization is applied. +pub fn load_image448_openai_norm<P: AsRef<std::path::Path>>(p: P) -> Result<Tensor> { + let img = image::io::Reader::open(p)? + .decode() + .map_err(candle::Error::wrap)? + .resize_to_fill(448, 448, image::imageops::FilterType::Triangle); + let img = img.to_rgb8(); + let data = img.into_raw(); + let data = Tensor::from_vec(data, (448, 448, 3), &Device::Cpu)?.permute((2, 0, 1))?; + let mean = + Tensor::new(&[0.48145466f32, 0.4578275, 0.40821073], &Device::Cpu)?.reshape((3, 1, 1))?; + let std = Tensor::new(&[0.26862954f32, 0.261_302_6, 0.275_777_1], &Device::Cpu)? + .reshape((3, 1, 1))?; + (data.to_dtype(candle::DType::F32)? / 255.)? + .broadcast_sub(&mean)? + .broadcast_div(&std) +} + +#[derive(Parser)] +struct Args { + #[arg(long)] + model: Option<String>, + + #[arg(long)] + image: String, + + /// Run on CPU rather than on GPU. + #[arg(long)] + cpu: bool, +} + +pub fn main() -> anyhow::Result<()> { + let args = Args::parse(); + + let device = candle_examples::device(args.cpu)?; + + let image = load_image448_openai_norm(args.image)?.to_device(&device)?; + println!("loaded image {image:?}"); + + let model_file = match args.model { + None => { + let api = hf_hub::api::sync::Api::new()?; + let api = api.model("vincent-espitalier/candle-eva2".into()); + api.get("eva02_base_patch14_448.mim_in22k_ft_in22k_in1k_adapted.safetensors")? + } + Some(model) => model.into(), + }; + let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? }; + + let model = eva2::vit_base(vb)?; + println!("model built"); + let logits = model.forward(&image.unsqueeze(0)?)?; + let prs = candle_nn::ops::softmax(&logits, D::Minus1)? + .i(0)? + .to_vec1::<f32>()?; + let mut prs = prs.iter().enumerate().collect::<Vec<_>>(); + prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1)); + for &(category_idx, pr) in prs.iter().take(5) { + println!( + "{:24}: {:.2}%", + candle_examples::imagenet::CLASSES[category_idx], + 100. * pr + ); + } + Ok(()) +} |