summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-08-21 10:55:09 +0100
committerGitHub <noreply@github.com>2023-08-21 10:55:09 +0100
commite3b71851e6d2c9eb51bb9978e7b025386d336f61 (patch)
tree610ae746c076cfe82226ad5040f6efe167b1afa5
parent4300864ce928cc901818771e5e1bb9202b96ede5 (diff)
downloadcandle-e3b71851e6d2c9eb51bb9978e7b025386d336f61.tar.gz
candle-e3b71851e6d2c9eb51bb9978e7b025386d336f61.tar.bz2
candle-e3b71851e6d2c9eb51bb9978e7b025386d336f61.zip
Retrieve the yolo-v3 weights from the hub. (#537)
-rw-r--r--candle-examples/examples/yolo-v3/main.rs36
1 files changed, 32 insertions, 4 deletions
diff --git a/candle-examples/examples/yolo-v3/main.rs b/candle-examples/examples/yolo-v3/main.rs
index a52f731c..0c7bdd7b 100644
--- a/candle-examples/examples/yolo-v3/main.rs
+++ b/candle-examples/examples/yolo-v3/main.rs
@@ -130,22 +130,50 @@ pub fn report(pred: &Tensor, img: DynamicImage, w: usize, h: usize) -> Result<Dy
struct Args {
/// Model weights, in safetensors format.
#[arg(long)]
- model: String,
+ model: Option<String>,
#[arg(long)]
- config: String,
+ config: Option<String>,
images: Vec<String>,
}
+impl Args {
+ fn config(&self) -> anyhow::Result<std::path::PathBuf> {
+ let path = match &self.config {
+ Some(config) => std::path::PathBuf::from(config),
+ None => {
+ let api = hf_hub::api::sync::Api::new()?;
+ let api = api.model("lmz/candle-yolo-v3".to_string());
+ api.get("yolo-v3.cfg")?
+ }
+ };
+ Ok(path)
+ }
+
+ fn model(&self) -> anyhow::Result<std::path::PathBuf> {
+ let path = match &self.model {
+ Some(model) => std::path::PathBuf::from(model),
+ None => {
+ let api = hf_hub::api::sync::Api::new()?;
+ let api = api.model("lmz/candle-yolo-v3".to_string());
+ api.get("yolo-v3.safetensors")?
+ }
+ };
+ Ok(path)
+ }
+}
+
pub fn main() -> Result<()> {
let args = Args::parse();
// Create the model and load the weights from the file.
- let weights = unsafe { candle::safetensors::MmapedFile::new(&args.model)? };
+ let model = args.model()?;
+ let weights = unsafe { candle::safetensors::MmapedFile::new(model)? };
let weights = weights.deserialize()?;
let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &Device::Cpu);
- let darknet = darknet::parse_config(&args.config)?;
+ let config = args.config()?;
+ let darknet = darknet::parse_config(config)?;
let model = darknet.build_model(vb)?;
for image_name in args.images.iter() {