diff options
Diffstat (limited to 'candle-examples/examples/yolo-v8/main.rs')
-rw-r--r-- | candle-examples/examples/yolo-v8/main.rs | 27 |
1 files changed, 24 insertions, 3 deletions
diff --git a/candle-examples/examples/yolo-v8/main.rs b/candle-examples/examples/yolo-v8/main.rs index d48bac35..dc709db4 100644 --- a/candle-examples/examples/yolo-v8/main.rs +++ b/candle-examples/examples/yolo-v8/main.rs @@ -7,7 +7,7 @@ extern crate accelerate_src; mod model; use model::{Multiples, YoloV8, YoloV8Pose}; -use candle::{DType, Device, IndexOp, Result, Tensor}; +use candle::{DType, IndexOp, Result, Tensor}; use candle_nn::{Module, VarBuilder}; use candle_transformers::object_detection::{non_maximum_suppression, Bbox, KeyPoint}; use clap::{Parser, ValueEnum}; @@ -253,6 +253,14 @@ enum YoloTask { #[derive(Parser, Debug)] #[command(author, version, about, long_about = None)] pub struct Args { + /// Run on CPU rather than on GPU. + #[arg(long)] + cpu: bool, + + /// Enable tracing (generates a trace-timestamp.json file). + #[arg(long)] + tracing: bool, + /// Model weights, in safetensors format. #[arg(long)] model: Option<String>, @@ -363,6 +371,7 @@ impl Task for YoloV8Pose { } pub fn run<T: Task>(args: Args) -> anyhow::Result<()> { + let device = candle_examples::device(args.cpu)?; // Create the model and load the weights from the file. let multiples = match args.which { Which::N => Multiples::n(), @@ -374,7 +383,7 @@ pub fn run<T: Task>(args: Args) -> anyhow::Result<()> { let model = args.model()?; let weights = unsafe { candle::safetensors::MmapedFile::new(model)? }; let weights = weights.deserialize()?; - let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &Device::Cpu); + let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device); let model = T::load(vb, multiples)?; println!("model loaded"); for image_name in args.images.iter() { @@ -405,7 +414,7 @@ pub fn run<T: Task>(args: Args) -> anyhow::Result<()> { Tensor::from_vec( data, (img.height() as usize, img.width() as usize, 3), - &Device::Cpu, + &device, )? .permute((2, 0, 1))? }; @@ -430,7 +439,19 @@ pub fn run<T: Task>(args: Args) -> anyhow::Result<()> { } pub fn main() -> anyhow::Result<()> { + use tracing_chrome::ChromeLayerBuilder; + use tracing_subscriber::prelude::*; + let args = Args::parse(); + + let _guard = if args.tracing { + let (chrome_layer, guard) = ChromeLayerBuilder::new().build(); + tracing_subscriber::registry().with(chrome_layer).init(); + Some(guard) + } else { + None + }; + match args.task { YoloTask::Detect => run::<YoloV8>(args)?, YoloTask::Pose => run::<YoloV8Pose>(args)?, |