diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-11-05 16:57:26 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-11-05 16:57:26 +0100 |
commit | f365a075e551dd50f7def29ecc2d8cba100c4625 (patch) | |
tree | 989c49f3317c59941284048c498e7801151b5e7c /candle-examples | |
parent | 60fdab4e17d3e420f20610ec75df3deccd8e1f69 (diff) | |
download | candle-f365a075e551dd50f7def29ecc2d8cba100c4625.tar.gz candle-f365a075e551dd50f7def29ecc2d8cba100c4625.tar.bz2 candle-f365a075e551dd50f7def29ecc2d8cba100c4625.zip |
Add more models to the onnx example. (#1273)
* Add more models to the onnx example.
* Input validation.
* Input validation.
* Bugfix.
* Implement clip.
* BatchNorm support.
* Get the efficientnet onnx to work.
Diffstat (limited to 'candle-examples')
-rw-r--r-- | candle-examples/examples/onnx/README.md (renamed from candle-examples/examples/squeezenet-onnx/README.md) | 0 | ||||
-rw-r--r-- | candle-examples/examples/onnx/main.rs (renamed from candle-examples/examples/squeezenet-onnx/main.rs) | 37 |
2 files changed, 29 insertions, 8 deletions
diff --git a/candle-examples/examples/squeezenet-onnx/README.md b/candle-examples/examples/onnx/README.md index fd705fb6..fd705fb6 100644 --- a/candle-examples/examples/squeezenet-onnx/README.md +++ b/candle-examples/examples/onnx/README.md diff --git a/candle-examples/examples/squeezenet-onnx/main.rs b/candle-examples/examples/onnx/main.rs index 90a38bf0..d3b0f8f8 100644 --- a/candle-examples/examples/squeezenet-onnx/main.rs +++ b/candle-examples/examples/onnx/main.rs @@ -5,7 +5,13 @@ extern crate intel_mkl_src; extern crate accelerate_src; use candle::{IndexOp, D}; -use clap::Parser; +use clap::{Parser, ValueEnum}; + +#[derive(Clone, Copy, Debug, ValueEnum)] +enum Which { + SqueezeNet, + EfficientNet, +} #[derive(Parser)] struct Args { @@ -14,19 +20,32 @@ struct Args { #[arg(long)] model: Option<String>, + + /// The model to be used. + #[arg(value_enum, long, default_value_t = Which::SqueezeNet)] + which: Which, } pub fn main() -> anyhow::Result<()> { let args = Args::parse(); let image = candle_examples::imagenet::load_image224(args.image)?; + let image = match args.which { + Which::SqueezeNet => image, + Which::EfficientNet => image.permute((1, 2, 0))?, + }; println!("loaded image {image:?}"); let model = match args.model { Some(model) => std::path::PathBuf::from(model), - None => hf_hub::api::sync::Api::new()? - .model("lmz/candle-onnx".into()) - .get("squeezenet1.1-7.onnx")?, + None => match args.which { + Which::SqueezeNet => hf_hub::api::sync::Api::new()? + .model("lmz/candle-onnx".into()) + .get("squeezenet1.1-7.onnx")?, + Which::EfficientNet => hf_hub::api::sync::Api::new()? + .model("onnx/EfficientNet-Lite4".into()) + .get("efficientnet-lite4-11.onnx")?, + }, }; let model = candle_onnx::read_file(model)?; @@ -34,10 +53,12 @@ pub fn main() -> anyhow::Result<()> { let mut inputs = std::collections::HashMap::new(); inputs.insert(graph.input[0].name.to_string(), image.unsqueeze(0)?); let mut outputs = candle_onnx::simple_eval(&model, inputs)?; - let logits = outputs.remove(&graph.output[0].name).unwrap(); - let prs = candle_nn::ops::softmax(&logits, D::Minus1)? - .i(0)? - .to_vec1::<f32>()?; + let output = outputs.remove(&graph.output[0].name).unwrap(); + let prs = match args.which { + Which::SqueezeNet => candle_nn::ops::softmax(&output, D::Minus1)?, + Which::EfficientNet => output, + }; + let prs = prs.i(0)?.to_vec1::<f32>()?; // Sort the predictions and take the top 5 let mut top: Vec<_> = prs.iter().enumerate().collect(); |