summaryrefslogtreecommitdiff
path: root/candle-examples/examples/yolo-v8/main.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-examples/examples/yolo-v8/main.rs')
-rw-r--r--candle-examples/examples/yolo-v8/main.rs27
1 files changed, 24 insertions, 3 deletions
diff --git a/candle-examples/examples/yolo-v8/main.rs b/candle-examples/examples/yolo-v8/main.rs
index d48bac35..dc709db4 100644
--- a/candle-examples/examples/yolo-v8/main.rs
+++ b/candle-examples/examples/yolo-v8/main.rs
@@ -7,7 +7,7 @@ extern crate accelerate_src;
mod model;
use model::{Multiples, YoloV8, YoloV8Pose};
-use candle::{DType, Device, IndexOp, Result, Tensor};
+use candle::{DType, IndexOp, Result, Tensor};
use candle_nn::{Module, VarBuilder};
use candle_transformers::object_detection::{non_maximum_suppression, Bbox, KeyPoint};
use clap::{Parser, ValueEnum};
@@ -253,6 +253,14 @@ enum YoloTask {
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
pub struct Args {
+ /// Run on CPU rather than on GPU.
+ #[arg(long)]
+ cpu: bool,
+
+ /// Enable tracing (generates a trace-timestamp.json file).
+ #[arg(long)]
+ tracing: bool,
+
/// Model weights, in safetensors format.
#[arg(long)]
model: Option<String>,
@@ -363,6 +371,7 @@ impl Task for YoloV8Pose {
}
pub fn run<T: Task>(args: Args) -> anyhow::Result<()> {
+ let device = candle_examples::device(args.cpu)?;
// Create the model and load the weights from the file.
let multiples = match args.which {
Which::N => Multiples::n(),
@@ -374,7 +383,7 @@ pub fn run<T: Task>(args: Args) -> anyhow::Result<()> {
let model = args.model()?;
let weights = unsafe { candle::safetensors::MmapedFile::new(model)? };
let weights = weights.deserialize()?;
- let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &Device::Cpu);
+ let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
let model = T::load(vb, multiples)?;
println!("model loaded");
for image_name in args.images.iter() {
@@ -405,7 +414,7 @@ pub fn run<T: Task>(args: Args) -> anyhow::Result<()> {
Tensor::from_vec(
data,
(img.height() as usize, img.width() as usize, 3),
- &Device::Cpu,
+ &device,
)?
.permute((2, 0, 1))?
};
@@ -430,7 +439,19 @@ pub fn run<T: Task>(args: Args) -> anyhow::Result<()> {
}
pub fn main() -> anyhow::Result<()> {
+ use tracing_chrome::ChromeLayerBuilder;
+ use tracing_subscriber::prelude::*;
+
let args = Args::parse();
+
+ let _guard = if args.tracing {
+ let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
+ tracing_subscriber::registry().with(chrome_layer).init();
+ Some(guard)
+ } else {
+ None
+ };
+
match args.task {
YoloTask::Detect => run::<YoloV8>(args)?,
YoloTask::Pose => run::<YoloV8Pose>(args)?,