summaryrefslogtreecommitdiff
path: root/candle-examples
diff options
context:
space:
mode:
authorv-espitalier <125037408+v-espitalier@users.noreply.github.com>2024-06-29 11:49:15 +0200
committerGitHub <noreply@github.com>2024-06-29 11:49:15 +0200
commite27aac0a062a6de125e2984eacdb7841664e86fd (patch)
treea0752c27f75da6c7312abb2a2219d9179e89d8db /candle-examples
parenta3dd87f15e3656ee2bec4820ae72a2a4e5662b40 (diff)
downloadcandle-e27aac0a062a6de125e2984eacdb7841664e86fd.tar.gz
candle-e27aac0a062a6de125e2984eacdb7841664e86fd.tar.bz2
candle-e27aac0a062a6de125e2984eacdb7841664e86fd.zip
Add DINOv2Reg4 + PlantCLEF2024 (#2293)
* Add: DINOv2Reg4 with PlantCLEF2024 weights and example ( See https://arxiv.org/abs/2309.16588 and https://zenodo.org/records/10848263 ) * Remove extra files + update README to download them + remove extra lines * minor fix (README remove extra spaces) * minor fix (README: Fix image url) * Modif: Add back interpolate_pos_encoding() + fix when no interpolation + remove extra comments + Update README ( source image changed and so the predictions ) * Fix: Improve code lisibility with '$ cargo clippy' and '$ cargo fmt' * Another clippy fix. --------- Co-authored-by: x-VEspit <vincent.espitalier@cirad.fr> Co-authored-by: laurent <laurent.mazare@gmail.com>
Diffstat (limited to 'candle-examples')
-rw-r--r--candle-examples/examples/dinov2reg4/README.md25
-rw-r--r--candle-examples/examples/dinov2reg4/main.rs70
-rw-r--r--candle-examples/src/imagenet.rs18
3 files changed, 113 insertions, 0 deletions
diff --git a/candle-examples/examples/dinov2reg4/README.md b/candle-examples/examples/dinov2reg4/README.md
new file mode 100644
index 00000000..ac86ca69
--- /dev/null
+++ b/candle-examples/examples/dinov2reg4/README.md
@@ -0,0 +1,25 @@
+# candle-dinov2-reg4
+
+[DINOv2-reg4](https://arxiv.org/abs/2309.16588) is the lastest version of DINOv2 with registers.
+In this example, it is used as an plant species classifier: the model returns the
+probability for the image to belong to each of the 7806 PlantCLEF2024 categories.
+
+## Running some example
+
+```bash
+# Download classes names and a plant picture to identify
+curl https://huggingface.co/vincent-espitalier/dino-v2-reg4-with-plantclef2024-weights/raw/main/species_id_mapping.txt --output candle-examples/examples/dinov2reg4/species_id_mapping.txt
+curl https://bs.plantnet.org/image/o/bd2d3830ac3270218ba82fd24e2290becd01317c --output candle-examples/examples/dinov2reg4/bd2d3830ac3270218ba82fd24e2290becd01317c.jpg
+
+# Perform inference
+cargo run --example dinov2reg4 --release -- --image candle-examples/examples/dinov2reg4/bd2d3830ac3270218ba82fd24e2290becd01317c.jpg
+
+> Orchis simia Lam. : 45.55%
+> Orchis × bergonii Nanteuil: 9.80%
+> Orchis italica Poir. : 9.66%
+> Orchis × angusticruris Franch.: 2.76%
+> Orchis × bivonae Tod. : 2.54%
+
+```
+
+![Orchis Simia](https://bs.plantnet.org/image/o/bd2d3830ac3270218ba82fd24e2290becd01317c)
diff --git a/candle-examples/examples/dinov2reg4/main.rs b/candle-examples/examples/dinov2reg4/main.rs
new file mode 100644
index 00000000..15270517
--- /dev/null
+++ b/candle-examples/examples/dinov2reg4/main.rs
@@ -0,0 +1,70 @@
+//! DINOv2 reg4 finetuned on PlantCLEF 2024
+//! https://arxiv.org/abs/2309.16588
+//! https://huggingface.co/spaces/BVRA/PlantCLEF2024
+//! https://zenodo.org/records/10848263
+
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
+
+#[cfg(feature = "accelerate")]
+extern crate accelerate_src;
+
+use clap::Parser;
+
+use candle::{DType, IndexOp, D};
+use candle_nn::{Module, VarBuilder};
+use candle_transformers::models::dinov2reg4;
+
+#[derive(Parser)]
+struct Args {
+ #[arg(long)]
+ model: Option<String>,
+
+ #[arg(long)]
+ image: String,
+
+ /// Run on CPU rather than on GPU.
+ #[arg(long)]
+ cpu: bool,
+}
+
+pub fn main() -> anyhow::Result<()> {
+ let args = Args::parse();
+
+ let device = candle_examples::device(args.cpu)?;
+
+ let image = candle_examples::imagenet::load_image518(args.image)?.to_device(&device)?;
+ println!("loaded image {image:?}");
+
+ let f_species_id_mapping = "candle-examples/examples/dinov2reg4/species_id_mapping.txt";
+ let classes: Vec<String> = std::fs::read_to_string(f_species_id_mapping)
+ .expect("missing classes file")
+ .split('\n')
+ .map(|s| s.to_string())
+ .collect();
+
+ let model_file = match args.model {
+ None => {
+ let api = hf_hub::api::sync::Api::new()?;
+ let api =
+ api.model("vincent-espitalier/dino-v2-reg4-with-plantclef2024-weights".into());
+ api.get(
+ "vit_base_patch14_reg4_dinov2_lvd142m_pc24_onlyclassifier_then_all.safetensors",
+ )?
+ }
+ Some(model) => model.into(),
+ };
+ let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
+ let model = dinov2reg4::vit_base(vb)?;
+ println!("model built");
+ let logits = model.forward(&image.unsqueeze(0)?)?;
+ let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
+ .i(0)?
+ .to_vec1::<f32>()?;
+ let mut prs = prs.iter().enumerate().collect::<Vec<_>>();
+ prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
+ for &(category_idx, pr) in prs.iter().take(5) {
+ println!("{:24}: {:.2}%", classes[category_idx], 100. * pr);
+ }
+ Ok(())
+}
diff --git a/candle-examples/src/imagenet.rs b/candle-examples/src/imagenet.rs
index cefbd71b..781dcd4f 100644
--- a/candle-examples/src/imagenet.rs
+++ b/candle-examples/src/imagenet.rs
@@ -17,6 +17,24 @@ pub fn load_image224<P: AsRef<std::path::Path>>(p: P) -> Result<Tensor> {
.broadcast_div(&std)
}
+/// Loads an image from disk using the image crate, this returns a tensor with shape
+/// (3, 518, 518). imagenet normalization is applied.
+/// The model dinov2 reg4 analyzes images with dimensions 3x518x518 (resulting in 37x37 transformer tokens).
+pub fn load_image518<P: AsRef<std::path::Path>>(p: P) -> Result<Tensor> {
+ let img = image::io::Reader::open(p)?
+ .decode()
+ .map_err(candle::Error::wrap)?
+ .resize_to_fill(518, 518, image::imageops::FilterType::Triangle);
+ let img = img.to_rgb8();
+ let data = img.into_raw();
+ let data = Tensor::from_vec(data, (518, 518, 3), &Device::Cpu)?.permute((2, 0, 1))?;
+ let mean = Tensor::new(&[0.485f32, 0.456, 0.406], &Device::Cpu)?.reshape((3, 1, 1))?;
+ let std = Tensor::new(&[0.229f32, 0.224, 0.225], &Device::Cpu)?.reshape((3, 1, 1))?;
+ (data.to_dtype(candle::DType::F32)? / 255.)?
+ .broadcast_sub(&mean)?
+ .broadcast_div(&std)
+}
+
pub const CLASS_COUNT: i64 = 1000;
pub const CLASSES: [&str; 1000] = [