diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-10-19 13:48:28 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-10-19 13:48:28 +0100 |
commit | 93c25e8844e8db2c697f1e6e9d4a06dac1ca3569 (patch) | |
tree | bbc00678fd08bbc8cfc392e5ff589477e80172b9 /candle-examples/examples/resnet/main.rs | |
parent | cd53c472df163b3baaf7c70ca5d4f8909af62324 (diff) | |
download | candle-93c25e8844e8db2c697f1e6e9d4a06dac1ca3569.tar.gz candle-93c25e8844e8db2c697f1e6e9d4a06dac1ca3569.tar.bz2 candle-93c25e8844e8db2c697f1e6e9d4a06dac1ca3569.zip |
Expose the larger resnets (50/101/152) in the example. (#1131)
Diffstat (limited to 'candle-examples/examples/resnet/main.rs')
-rw-r--r-- | candle-examples/examples/resnet/main.rs | 14 |
1 files changed, 14 insertions, 0 deletions
diff --git a/candle-examples/examples/resnet/main.rs b/candle-examples/examples/resnet/main.rs index 3badc48e..4a4592ad 100644 --- a/candle-examples/examples/resnet/main.rs +++ b/candle-examples/examples/resnet/main.rs @@ -11,8 +11,16 @@ use clap::{Parser, ValueEnum}; #[derive(Clone, Copy, Debug, ValueEnum)] enum Which { + #[value(name = "18")] Resnet18, + #[value(name = "34")] Resnet34, + #[value(name = "50")] + Resnet50, + #[value(name = "101")] + Resnet101, + #[value(name = "152")] + Resnet152, } #[derive(Parser)] @@ -47,6 +55,9 @@ pub fn main() -> anyhow::Result<()> { let filename = match args.which { Which::Resnet18 => "resnet18.safetensors", Which::Resnet34 => "resnet34.safetensors", + Which::Resnet50 => "resnet50.safetensors", + Which::Resnet101 => "resnet101.safetensors", + Which::Resnet152 => "resnet152.safetensors", }; api.get(filename)? } @@ -57,6 +68,9 @@ pub fn main() -> anyhow::Result<()> { let model = match args.which { Which::Resnet18 => resnet::resnet18(class_count, vb)?, Which::Resnet34 => resnet::resnet34(class_count, vb)?, + Which::Resnet50 => resnet::resnet50(class_count, vb)?, + Which::Resnet101 => resnet::resnet101(class_count, vb)?, + Which::Resnet152 => resnet::resnet152(class_count, vb)?, }; println!("model built"); let logits = model.forward(&image.unsqueeze(0)?)?; |