summaryrefslogtreecommitdiff
path: root/candle-examples/examples/dinov2
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-08-18 17:32:58 +0100
committerGitHub <noreply@github.com>2023-08-18 17:32:58 +0100
commite5dd5fd1b3cdff3ac5817d2061a54e397dcb0aae (patch)
tree8a81dcd3798158f6e1a5445e932625b2662981d1 /candle-examples/examples/dinov2
parentcb069d606323cec02c6bf54185c2fbfffffd4bdf (diff)
downloadcandle-e5dd5fd1b3cdff3ac5817d2061a54e397dcb0aae.tar.gz
candle-e5dd5fd1b3cdff3ac5817d2061a54e397dcb0aae.tar.bz2
candle-e5dd5fd1b3cdff3ac5817d2061a54e397dcb0aae.zip
Print the recognized categories in dino-v2. (#506)
Diffstat (limited to 'candle-examples/examples/dinov2')
-rw-r--r--candle-examples/examples/dinov2/main.rs14
1 files changed, 12 insertions, 2 deletions
diff --git a/candle-examples/examples/dinov2/main.rs b/candle-examples/examples/dinov2/main.rs
index 0d6c4b85..2de28459 100644
--- a/candle-examples/examples/dinov2/main.rs
+++ b/candle-examples/examples/dinov2/main.rs
@@ -315,7 +315,17 @@ pub fn main() -> anyhow::Result<()> {
let model = vit_small(vb)?;
println!("model built");
let logits = model.forward(&image.unsqueeze(0)?)?;
- let prs = candle_nn::ops::softmax(&logits, D::Minus1)?;
- println!("{prs}");
+ let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
+ .i(0)?
+ .to_vec1::<f32>()?;
+ let mut prs = prs.iter().enumerate().collect::<Vec<_>>();
+ prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
+ for &(category_idx, pr) in prs.iter().take(5) {
+ println!(
+ "{:24}: {:.2}%",
+ candle_examples::IMAGENET_CLASSES[category_idx],
+ 100. * pr
+ );
+ }
Ok(())
}