diff options
Diffstat (limited to 'candle-examples/examples/resnet/main.rs')
-rw-r--r-- | candle-examples/examples/resnet/main.rs | 14 |
1 files changed, 14 insertions, 0 deletions
diff --git a/candle-examples/examples/resnet/main.rs b/candle-examples/examples/resnet/main.rs index 3badc48e..4a4592ad 100644 --- a/candle-examples/examples/resnet/main.rs +++ b/candle-examples/examples/resnet/main.rs @@ -11,8 +11,16 @@ use clap::{Parser, ValueEnum}; #[derive(Clone, Copy, Debug, ValueEnum)] enum Which { + #[value(name = "18")] Resnet18, + #[value(name = "34")] Resnet34, + #[value(name = "50")] + Resnet50, + #[value(name = "101")] + Resnet101, + #[value(name = "152")] + Resnet152, } #[derive(Parser)] @@ -47,6 +55,9 @@ pub fn main() -> anyhow::Result<()> { let filename = match args.which { Which::Resnet18 => "resnet18.safetensors", Which::Resnet34 => "resnet34.safetensors", + Which::Resnet50 => "resnet50.safetensors", + Which::Resnet101 => "resnet101.safetensors", + Which::Resnet152 => "resnet152.safetensors", }; api.get(filename)? } @@ -57,6 +68,9 @@ pub fn main() -> anyhow::Result<()> { let model = match args.which { Which::Resnet18 => resnet::resnet18(class_count, vb)?, Which::Resnet34 => resnet::resnet34(class_count, vb)?, + Which::Resnet50 => resnet::resnet50(class_count, vb)?, + Which::Resnet101 => resnet::resnet101(class_count, vb)?, + Which::Resnet152 => resnet::resnet152(class_count, vb)?, }; println!("model built"); let logits = model.forward(&image.unsqueeze(0)?)?; |