author     v-espitalier <125037408+v-espitalier@users.noreply.github.com>  2024-07-07 20:09:31 +0200
committer  GitHub <noreply@github.com>  2024-07-07 20:09:31 +0200
commit     9cd54aa5d4fb6cf30e0df2d198c8a387db2d9144 (patch)
tree       9988e786128416cf8c77425658e29a67c904a5ad /candle-examples/examples/eva2
parent     eec11ce2ce1e81f0fdb1cac5405d07286242dc01 (diff)
Add EVA-02 model (https://arxiv.org/abs/2303.11331) (#2311)

* Add EVA-02 model (https://arxiv.org/abs/2303.11331)
* Clippy fix.
* And apply fmt.

---------

Co-authored-by: v-espitalier <>
Co-authored-by: Laurent <laurent.mazare@gmail.com>
Diffstat (limited to 'candle-examples/examples/eva2')
-rw-r--r--  candle-examples/examples/eva2/README.md  21
-rw-r--r--  candle-examples/examples/eva2/main.rs    82
2 files changed, 103 insertions, 0 deletions
diff --git a/candle-examples/examples/eva2/README.md b/candle-examples/examples/eva2/README.md
new file mode 100644
index 00000000..10c91b89
--- /dev/null
+++ b/candle-examples/examples/eva2/README.md
@@ -0,0 +1,21 @@
+# candle-eva2
+
+[EVA-02](https://arxiv.org/abs/2303.11331) is a computer vision model.
+In this example, it is used as an ImageNet classifier: the model returns the
+probability that the image belongs to each of the 1000 ImageNet categories.
+
+## Running an example
+
+```bash
+cargo run --example eva2 --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg
+
+> mountain bike, all-terrain bike, off-roader: 37.09%
+> maillot                 : 8.30%
+> alp                     : 2.13%
+> bicycle-built-for-two, tandem bicycle, tandem: 0.84%
+> crash helmet            : 0.73%
+
+
+```
+
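+The example also accepts a `--cpu` flag (defined in `main.rs`) to run the model
+on CPU rather than on GPU, and a `--model` flag to load the weights from a local
+safetensors file instead of fetching them from the Hugging Face Hub. For instance:
+
+```bash
+cargo run --example eva2 --release -- \
+  --image candle-examples/examples/yolo-v8/assets/bike.jpg \
+  --cpu
+```
+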
+![Leading group, Giro d'Italia 2021](../yolo-v8/assets/bike.jpg)
diff --git a/candle-examples/examples/eva2/main.rs b/candle-examples/examples/eva2/main.rs
new file mode 100644
index 00000000..4270075d
--- /dev/null
+++ b/candle-examples/examples/eva2/main.rs
@@ -0,0 +1,82 @@
+//! EVA-02: Explore the limits of Visual representation at scAle
+//! https://github.com/baaivision/EVA
+
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
+
+#[cfg(feature = "accelerate")]
+extern crate accelerate_src;
+
+use clap::Parser;
+
+use candle::{DType, Device, IndexOp, Result, Tensor, D};
+use candle_nn::{Module, VarBuilder};
+use candle_transformers::models::eva2;
+
+/// Loads an image from disk using the image crate; it returns a tensor with shape
+/// (3, 448, 448). OpenAI normalization is applied.
+pub fn load_image448_openai_norm<P: AsRef<std::path::Path>>(p: P) -> Result<Tensor> {
+    let img = image::io::Reader::open(p)?
+        .decode()
+        .map_err(candle::Error::wrap)?
+        .resize_to_fill(448, 448, image::imageops::FilterType::Triangle);
+    let img = img.to_rgb8();
+    let data = img.into_raw();
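+    // Move from HWC (448, 448, 3) to the CHW (3, 448, 448) layout expected by the model.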
+    let data = Tensor::from_vec(data, (448, 448, 3), &Device::Cpu)?.permute((2, 0, 1))?;
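+    // Per-channel mean and std values from the OpenAI CLIP preprocessing.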
+    let mean =
+        Tensor::new(&[0.48145466f32, 0.4578275, 0.40821073], &Device::Cpu)?.reshape((3, 1, 1))?;
+    let std = Tensor::new(&[0.26862954f32, 0.261_302_6, 0.275_777_1], &Device::Cpu)?
+        .reshape((3, 1, 1))?;
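+    // Scale pixel values to [0, 1] before the channel-wise normalization.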
+    (data.to_dtype(candle::DType::F32)? / 255.)?
+        .broadcast_sub(&mean)?
+        .broadcast_div(&std)
+}
+
+#[derive(Parser)]
+struct Args {
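+    /// Local safetensors weight file; fetched from the Hugging Face Hub when unset.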
+    #[arg(long)]
+    model: Option<String>,
+
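+    /// Path of the input image.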
+    #[arg(long)]
+    image: String,
+
+    /// Run on CPU rather than on GPU.
+    #[arg(long)]
+    cpu: bool,
+}
+
+pub fn main() -> anyhow::Result<()> {
+    let args = Args::parse();
+
+    let device = candle_examples::device(args.cpu)?;
+
+    let image = load_image448_openai_norm(args.image)?.to_device(&device)?;
+    println!("loaded image {image:?}");
+
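+    // Download the weights from the Hugging Face Hub unless --model points at a local file.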
+    let model_file = match args.model {
+        None => {
+            let api = hf_hub::api::sync::Api::new()?;
+            let api = api.model("vincent-espitalier/candle-eva2".into());
+            api.get("eva02_base_patch14_448.mim_in22k_ft_in22k_in1k_adapted.safetensors")?
+        }
+        Some(model) => model.into(),
+    };
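+    // Safety: the weights are memory-mapped, so the file must not be modified while in use.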
+    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
+
+    let model = eva2::vit_base(vb)?;
+    println!("model built");
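+    // Add a batch dimension: (3, 448, 448) -> (1, 3, 448, 448).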
+    let logits = model.forward(&image.unsqueeze(0)?)?;
+    let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
+        .i(0)?
+        .to_vec1::<f32>()?;
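+    // Sort the class probabilities in descending order and report the top 5.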
+    let mut prs = prs.iter().enumerate().collect::<Vec<_>>();
+    prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
+    for &(category_idx, pr) in prs.iter().take(5) {
+        println!(
+            "{:24}: {:.2}%",
+            candle_examples::imagenet::CLASSES[category_idx],
+            100. * pr
+        );
+    }
+    Ok(())
+}