diff options
-rw-r--r-- | candle-examples/examples/yolo-v3/main.rs | 36 |
1 files changed, 32 insertions, 4 deletions
diff --git a/candle-examples/examples/yolo-v3/main.rs b/candle-examples/examples/yolo-v3/main.rs index a52f731c..0c7bdd7b 100644 --- a/candle-examples/examples/yolo-v3/main.rs +++ b/candle-examples/examples/yolo-v3/main.rs @@ -130,22 +130,50 @@ pub fn report(pred: &Tensor, img: DynamicImage, w: usize, h: usize) -> Result<Dy struct Args { /// Model weights, in safetensors format. #[arg(long)] - model: String, + model: Option<String>, #[arg(long)] - config: String, + config: Option<String>, images: Vec<String>, } +impl Args { + fn config(&self) -> anyhow::Result<std::path::PathBuf> { + let path = match &self.config { + Some(config) => std::path::PathBuf::from(config), + None => { + let api = hf_hub::api::sync::Api::new()?; + let api = api.model("lmz/candle-yolo-v3".to_string()); + api.get("yolo-v3.cfg")? + } + }; + Ok(path) + } + + fn model(&self) -> anyhow::Result<std::path::PathBuf> { + let path = match &self.model { + Some(model) => std::path::PathBuf::from(model), + None => { + let api = hf_hub::api::sync::Api::new()?; + let api = api.model("lmz/candle-yolo-v3".to_string()); + api.get("yolo-v3.safetensors")? + } + }; + Ok(path) + } +} + pub fn main() -> Result<()> { let args = Args::parse(); // Create the model and load the weights from the file. - let weights = unsafe { candle::safetensors::MmapedFile::new(&args.model)? }; + let model = args.model()?; + let weights = unsafe { candle::safetensors::MmapedFile::new(model)? }; let weights = weights.deserialize()?; let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &Device::Cpu); - let darknet = darknet::parse_config(&args.config)?; + let config = args.config()?; + let darknet = darknet::parse_config(config)?; let model = darknet.build_model(vb)?; for image_name in args.images.iter() { |