diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-08-18 17:32:58 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-08-18 17:32:58 +0100 |
commit | e5dd5fd1b3cdff3ac5817d2061a54e397dcb0aae (patch) | |
tree | 8a81dcd3798158f6e1a5445e932625b2662981d1 /candle-examples/examples/dinov2 | |
parent | cb069d606323cec02c6bf54185c2fbfffffd4bdf (diff) | |
download | candle-e5dd5fd1b3cdff3ac5817d2061a54e397dcb0aae.tar.gz candle-e5dd5fd1b3cdff3ac5817d2061a54e397dcb0aae.tar.bz2 candle-e5dd5fd1b3cdff3ac5817d2061a54e397dcb0aae.zip |
Print the recognized categories in dino-v2. (#506)
Diffstat (limited to 'candle-examples/examples/dinov2')
-rw-r--r-- | candle-examples/examples/dinov2/main.rs | 14 |
1 files changed, 12 insertions, 2 deletions
diff --git a/candle-examples/examples/dinov2/main.rs b/candle-examples/examples/dinov2/main.rs index 0d6c4b85..2de28459 100644 --- a/candle-examples/examples/dinov2/main.rs +++ b/candle-examples/examples/dinov2/main.rs @@ -315,7 +315,17 @@ pub fn main() -> anyhow::Result<()> { let model = vit_small(vb)?; println!("model built"); let logits = model.forward(&image.unsqueeze(0)?)?; - let prs = candle_nn::ops::softmax(&logits, D::Minus1)?; - println!("{prs}"); + let prs = candle_nn::ops::softmax(&logits, D::Minus1)? + .i(0)? + .to_vec1::<f32>()?; + let mut prs = prs.iter().enumerate().collect::<Vec<_>>(); + prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1)); + for &(category_idx, pr) in prs.iter().take(5) { + println!( + "{:24}: {:.2}%", + candle_examples::IMAGENET_CLASSES[category_idx], + 100. * pr + ); + } Ok(()) } |