diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-10-11 20:51:10 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-10-11 19:51:10 +0100 |
commit | e7560443e4680b7655d011948d3cf178268fcfff (patch) | |
tree | fe7f26799c6afee59c30c9c33b24a390e397aae2 /candle-examples/examples/convmixer/main.rs | |
parent | 89b525b5e758218179dd32293e7167e3aae1b28f (diff) | |
download | candle-e7560443e4680b7655d011948d3cf178268fcfff.tar.gz candle-e7560443e4680b7655d011948d3cf178268fcfff.tar.bz2 candle-e7560443e4680b7655d011948d3cf178268fcfff.zip |
Convmixer example (#1074)
* Add a convmixer based example.
* Mention the model in the readme.
Diffstat (limited to 'candle-examples/examples/convmixer/main.rs')
-rw-r--r-- | candle-examples/examples/convmixer/main.rs | 59 |
1 files changed, 59 insertions, 0 deletions
diff --git a/candle-examples/examples/convmixer/main.rs b/candle-examples/examples/convmixer/main.rs new file mode 100644 index 00000000..feae536f --- /dev/null +++ b/candle-examples/examples/convmixer/main.rs @@ -0,0 +1,59 @@ +#[cfg(feature = "mkl")] +extern crate intel_mkl_src; + +#[cfg(feature = "accelerate")] +extern crate accelerate_src; + +use clap::Parser; + +use candle::{DType, IndexOp, D}; +use candle_nn::{Module, VarBuilder}; +use candle_transformers::models::convmixer; + +#[derive(Parser)] +struct Args { + #[arg(long)] + model: Option<String>, + + #[arg(long)] + image: String, + + /// Run on CPU rather than on GPU. + #[arg(long)] + cpu: bool, +} + +pub fn main() -> anyhow::Result<()> { + let args = Args::parse(); + + let device = candle_examples::device(args.cpu)?; + + let image = candle_examples::imagenet::load_image224(args.image)?; + println!("loaded image {image:?}"); + + let model_file = match args.model { + None => { + let api = hf_hub::api::sync::Api::new()?; + let api = api.model("lmz/candle-convmixer".into()); + api.get("convmixer_1024_20_ks9_p14.safetensors")? + } + Some(model) => model.into(), + }; + let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? }; + let model = convmixer::c1024_20(1000, vb)?; + println!("model built"); + let logits = model.forward(&image.unsqueeze(0)?)?; + let prs = candle_nn::ops::softmax(&logits, D::Minus1)? + .i(0)? + .to_vec1::<f32>()?; + let mut prs = prs.iter().enumerate().collect::<Vec<_>>(); + prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1)); + for &(category_idx, pr) in prs.iter().take(5) { + println!( + "{:24}: {:.2}%", + candle_examples::imagenet::CLASSES[category_idx], + 100. * pr + ); + } + Ok(()) +} |