summaryrefslogtreecommitdiff
path: root/candle-examples/examples/dinov2/main.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-examples/examples/dinov2/main.rs')
-rw-r--r--candle-examples/examples/dinov2/main.rs14
1 files changed, 12 insertions, 2 deletions
diff --git a/candle-examples/examples/dinov2/main.rs b/candle-examples/examples/dinov2/main.rs
index 0d6c4b85..2de28459 100644
--- a/candle-examples/examples/dinov2/main.rs
+++ b/candle-examples/examples/dinov2/main.rs
@@ -315,7 +315,17 @@ pub fn main() -> anyhow::Result<()> {
let model = vit_small(vb)?;
println!("model built");
let logits = model.forward(&image.unsqueeze(0)?)?;
- let prs = candle_nn::ops::softmax(&logits, D::Minus1)?;
- println!("{prs}");
+ let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
+ .i(0)?
+ .to_vec1::<f32>()?;
+ let mut prs = prs.iter().enumerate().collect::<Vec<_>>();
+ prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
+ for &(category_idx, pr) in prs.iter().take(5) {
+ println!(
+ "{:24}: {:.2}%",
+ candle_examples::IMAGENET_CLASSES[category_idx],
+ 100. * pr
+ );
+ }
Ok(())
}