summaryrefslogtreecommitdiff
path: root/candle-examples/examples/resnet/main.rs
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-10-19 13:48:28 +0100
committerGitHub <noreply@github.com>2023-10-19 13:48:28 +0100
commit93c25e8844e8db2c697f1e6e9d4a06dac1ca3569 (patch)
treebbc00678fd08bbc8cfc392e5ff589477e80172b9 /candle-examples/examples/resnet/main.rs
parentcd53c472df163b3baaf7c70ca5d4f8909af62324 (diff)
downloadcandle-93c25e8844e8db2c697f1e6e9d4a06dac1ca3569.tar.gz
candle-93c25e8844e8db2c697f1e6e9d4a06dac1ca3569.tar.bz2
candle-93c25e8844e8db2c697f1e6e9d4a06dac1ca3569.zip
Expose the larger resnets (50/101/152) in the example. (#1131)
Diffstat (limited to 'candle-examples/examples/resnet/main.rs')
-rw-r--r--candle-examples/examples/resnet/main.rs14
1 files changed, 14 insertions, 0 deletions
diff --git a/candle-examples/examples/resnet/main.rs b/candle-examples/examples/resnet/main.rs
index 3badc48e..4a4592ad 100644
--- a/candle-examples/examples/resnet/main.rs
+++ b/candle-examples/examples/resnet/main.rs
@@ -11,8 +11,16 @@ use clap::{Parser, ValueEnum};
#[derive(Clone, Copy, Debug, ValueEnum)]
enum Which {
+ #[value(name = "18")]
Resnet18,
+ #[value(name = "34")]
Resnet34,
+ #[value(name = "50")]
+ Resnet50,
+ #[value(name = "101")]
+ Resnet101,
+ #[value(name = "152")]
+ Resnet152,
}
#[derive(Parser)]
@@ -47,6 +55,9 @@ pub fn main() -> anyhow::Result<()> {
let filename = match args.which {
Which::Resnet18 => "resnet18.safetensors",
Which::Resnet34 => "resnet34.safetensors",
+ Which::Resnet50 => "resnet50.safetensors",
+ Which::Resnet101 => "resnet101.safetensors",
+ Which::Resnet152 => "resnet152.safetensors",
};
api.get(filename)?
}
@@ -57,6 +68,9 @@ pub fn main() -> anyhow::Result<()> {
let model = match args.which {
Which::Resnet18 => resnet::resnet18(class_count, vb)?,
Which::Resnet34 => resnet::resnet34(class_count, vb)?,
+ Which::Resnet50 => resnet::resnet50(class_count, vb)?,
+ Which::Resnet101 => resnet::resnet101(class_count, vb)?,
+ Which::Resnet152 => resnet::resnet152(class_count, vb)?,
};
println!("model built");
let logits = model.forward(&image.unsqueeze(0)?)?;