summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
-rw-r--r-- candle-examples/examples/yolo-v8/main.rs 27
-rw-r--r-- candle-examples/examples/yolo-v8/model.rs 73
2 files changed, 90 insertions, 10 deletions
diff --git a/candle-examples/examples/yolo-v8/main.rs b/candle-examples/examples/yolo-v8/main.rs
index d48bac35..dc709db4 100644
--- a/candle-examples/examples/yolo-v8/main.rs
+++ b/candle-examples/examples/yolo-v8/main.rs
@@ -7,7 +7,7 @@ extern crate accelerate_src;
mod model;
use model::{Multiples, YoloV8, YoloV8Pose};
-use candle::{DType, Device, IndexOp, Result, Tensor};
+use candle::{DType, IndexOp, Result, Tensor};
use candle_nn::{Module, VarBuilder};
use candle_transformers::object_detection::{non_maximum_suppression, Bbox, KeyPoint};
use clap::{Parser, ValueEnum};
@@ -253,6 +253,14 @@ enum YoloTask {
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
pub struct Args {
+ /// Run on CPU rather than on GPU.
+ #[arg(long)]
+ cpu: bool,
+
+ /// Enable tracing (generates a trace-timestamp.json file).
+ #[arg(long)]
+ tracing: bool,
+
/// Model weights, in safetensors format.
#[arg(long)]
model: Option<String>,
@@ -363,6 +371,7 @@ impl Task for YoloV8Pose {
}
pub fn run<T: Task>(args: Args) -> anyhow::Result<()> {
+ let device = candle_examples::device(args.cpu)?;
// Create the model and load the weights from the file.
let multiples = match args.which {
Which::N => Multiples::n(),
@@ -374,7 +383,7 @@ pub fn run<T: Task>(args: Args) -> anyhow::Result<()> {
let model = args.model()?;
let weights = unsafe { candle::safetensors::MmapedFile::new(model)? };
let weights = weights.deserialize()?;
- let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &Device::Cpu);
+ let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
let model = T::load(vb, multiples)?;
println!("model loaded");
for image_name in args.images.iter() {
@@ -405,7 +414,7 @@ pub fn run<T: Task>(args: Args) -> anyhow::Result<()> {
Tensor::from_vec(
data,
(img.height() as usize, img.width() as usize, 3),
- &Device::Cpu,
+ &device,
)?
.permute((2, 0, 1))?
};
@@ -430,7 +439,19 @@ pub fn run<T: Task>(args: Args) -> anyhow::Result<()> {
}
pub fn main() -> anyhow::Result<()> {
+ use tracing_chrome::ChromeLayerBuilder;
+ use tracing_subscriber::prelude::*;
+
let args = Args::parse();
+
+ let _guard = if args.tracing {
+ let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
+ tracing_subscriber::registry().with(chrome_layer).init();
+ Some(guard)
+ } else {
+ None
+ };
+
match args.task {
YoloTask::Detect => run::<YoloV8>(args)?,
YoloTask::Pose => run::<YoloV8Pose>(args)?,
diff --git a/candle-examples/examples/yolo-v8/model.rs b/candle-examples/examples/yolo-v8/model.rs
index b834f967..bf48fd84 100644
--- a/candle-examples/examples/yolo-v8/model.rs
+++ b/candle-examples/examples/yolo-v8/model.rs
@@ -77,6 +77,7 @@ impl Module for Upsample {
struct ConvBlock {
conv: Conv2d,
bn: BatchNorm,
+ span: tracing::Span,
}
impl ConvBlock {
@@ -97,12 +98,17 @@ impl ConvBlock {
};
let conv = conv2d_no_bias(c1, c2, k, cfg, vb.pp("conv"))?;
let bn = batch_norm(c2, 1e-3, vb.pp("bn"))?;
- Ok(Self { conv, bn })
+ Ok(Self {
+ conv,
+ bn,
+ span: tracing::span!(tracing::Level::TRACE, "conv-block"),
+ })
}
}
impl Module for ConvBlock {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
let xs = self.conv.forward(xs)?;
let xs = self.bn.forward(&xs)?;
candle_nn::ops::silu(&xs)
@@ -114,6 +120,7 @@ struct Bottleneck {
cv1: ConvBlock,
cv2: ConvBlock,
residual: bool,
+ span: tracing::Span,
}
impl Bottleneck {
@@ -123,12 +130,18 @@ impl Bottleneck {
let cv1 = ConvBlock::load(vb.pp("cv1"), c1, c_, 3, 1, None)?;
let cv2 = ConvBlock::load(vb.pp("cv2"), c_, c2, 3, 1, None)?;
let residual = c1 == c2 && shortcut;
- Ok(Self { cv1, cv2, residual })
+ Ok(Self {
+ cv1,
+ cv2,
+ residual,
+ span: tracing::span!(tracing::Level::TRACE, "bottleneck"),
+ })
}
}
impl Module for Bottleneck {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
let ys = self.cv2.forward(&self.cv1.forward(xs)?)?;
if self.residual {
xs + ys
@@ -143,6 +156,7 @@ struct C2f {
cv1: ConvBlock,
cv2: ConvBlock,
bottleneck: Vec<Bottleneck>,
+ span: tracing::Span,
}
impl C2f {
@@ -159,12 +173,14 @@ impl C2f {
cv1,
cv2,
bottleneck,
+ span: tracing::span!(tracing::Level::TRACE, "c2f"),
})
}
}
impl Module for C2f {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
let ys = self.cv1.forward(xs)?;
let mut ys = ys.chunk(2, 1)?;
for m in self.bottleneck.iter() {
@@ -180,6 +196,7 @@ struct Sppf {
cv1: ConvBlock,
cv2: ConvBlock,
k: usize,
+ span: tracing::Span,
}
impl Sppf {
@@ -187,12 +204,18 @@ impl Sppf {
let c_ = c1 / 2;
let cv1 = ConvBlock::load(vb.pp("cv1"), c1, c_, 1, 1, None)?;
let cv2 = ConvBlock::load(vb.pp("cv2"), c_ * 4, c2, 1, 1, None)?;
- Ok(Self { cv1, cv2, k })
+ Ok(Self {
+ cv1,
+ cv2,
+ k,
+ span: tracing::span!(tracing::Level::TRACE, "sppf"),
+ })
}
}
impl Module for Sppf {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
let (_, _, _, _) = xs.dims4()?;
let xs = self.cv1.forward(xs)?;
let xs2 = xs
@@ -215,17 +238,23 @@ impl Module for Sppf {
struct Dfl {
conv: Conv2d,
num_classes: usize,
+ span: tracing::Span,
}
impl Dfl {
fn load(vb: VarBuilder, num_classes: usize) -> Result<Self> {
let conv = conv2d_no_bias(num_classes, 1, 1, Default::default(), vb.pp("conv"))?;
- Ok(Self { conv, num_classes })
+ Ok(Self {
+ conv,
+ num_classes,
+ span: tracing::span!(tracing::Level::TRACE, "dfl"),
+ })
}
}
impl Module for Dfl {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
let (b_sz, _channels, anchors) = xs.dims3()?;
let xs = xs
.reshape((b_sz, 4, self.num_classes, anchors))?
@@ -247,6 +276,7 @@ struct DarkNet {
b4_0: ConvBlock,
b4_1: C2f,
b5: Sppf,
+ span: tracing::Span,
}
impl DarkNet {
@@ -330,10 +360,12 @@ impl DarkNet {
b4_0,
b4_1,
b5,
+ span: tracing::span!(tracing::Level::TRACE, "darknet"),
})
}
fn forward(&self, xs: &Tensor) -> Result<(Tensor, Tensor, Tensor)> {
+ let _enter = self.span.enter();
let x1 = self.b1_1.forward(&self.b1_0.forward(xs)?)?;
let x2 = self
.b2_2
@@ -354,6 +386,7 @@ struct YoloV8Neck {
n4: C2f,
n5: ConvBlock,
n6: C2f,
+ span: tracing::Span,
}
impl YoloV8Neck {
@@ -413,10 +446,12 @@ impl YoloV8Neck {
n4,
n5,
n6,
+ span: tracing::span!(tracing::Level::TRACE, "neck"),
})
}
fn forward(&self, p3: &Tensor, p4: &Tensor, p5: &Tensor) -> Result<(Tensor, Tensor, Tensor)> {
+ let _enter = self.span.enter();
let x = self
.n1
.forward(&Tensor::cat(&[&self.up.forward(p5)?, p4], 1)?)?;
@@ -440,6 +475,7 @@ struct DetectionHead {
cv3: [(ConvBlock, ConvBlock, Conv2d); 3],
ch: usize,
no: usize,
+ span: tracing::Span,
}
#[derive(Debug)]
@@ -447,6 +483,7 @@ struct PoseHead {
detect: DetectionHead,
cv4: [(ConvBlock, ConvBlock, Conv2d); 3],
kpt: (usize, usize),
+ span: tracing::Span,
}
fn make_anchors(
@@ -519,6 +556,7 @@ impl DetectionHead {
cv3,
ch,
no,
+ span: tracing::span!(tracing::Level::TRACE, "detection-head"),
})
}
@@ -547,6 +585,7 @@ impl DetectionHead {
}
fn forward(&self, xs0: &Tensor, xs1: &Tensor, xs2: &Tensor) -> Result<DetectionHeadOut> {
+ let _enter = self.span.enter();
let forward_cv = |xs, i: usize| {
let xs_2 = self.cv2[i].0.forward(xs)?;
let xs_2 = self.cv2[i].1.forward(&xs_2)?;
@@ -606,7 +645,12 @@ impl PoseHead {
Self::load_cv4(vb.pp("cv4.1"), c4, nk, filters.1)?,
Self::load_cv4(vb.pp("cv4.2"), c4, nk, filters.2)?,
];
- Ok(Self { detect, cv4, kpt })
+ Ok(Self {
+ detect,
+ cv4,
+ kpt,
+ span: tracing::span!(tracing::Level::TRACE, "pose-head"),
+ })
}
fn load_cv4(
@@ -622,6 +666,7 @@ impl PoseHead {
}
fn forward(&self, xs0: &Tensor, xs1: &Tensor, xs2: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
let d = self.detect.forward(xs0, xs1, xs2)?;
let forward_cv = |xs: &Tensor, i: usize| {
let (b_sz, _, h, w) = xs.dims4()?;
@@ -650,6 +695,7 @@ pub struct YoloV8 {
net: DarkNet,
fpn: YoloV8Neck,
head: DetectionHead,
+ span: tracing::Span,
}
impl YoloV8 {
@@ -657,12 +703,18 @@ impl YoloV8 {
let net = DarkNet::load(vb.pp("net"), m)?;
let fpn = YoloV8Neck::load(vb.pp("fpn"), m)?;
let head = DetectionHead::load(vb.pp("head"), num_classes, m.filters())?;
- Ok(Self { net, fpn, head })
+ Ok(Self {
+ net,
+ fpn,
+ head,
+ span: tracing::span!(tracing::Level::TRACE, "yolo-v8"),
+ })
}
}
impl Module for YoloV8 {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
let (xs1, xs2, xs3) = self.net.forward(xs)?;
let (xs1, xs2, xs3) = self.fpn.forward(&xs1, &xs2, &xs3)?;
Ok(self.head.forward(&xs1, &xs2, &xs3)?.pred)
@@ -674,6 +726,7 @@ pub struct YoloV8Pose {
net: DarkNet,
fpn: YoloV8Neck,
head: PoseHead,
+ span: tracing::Span,
}
impl YoloV8Pose {
@@ -686,12 +739,18 @@ impl YoloV8Pose {
let net = DarkNet::load(vb.pp("net"), m)?;
let fpn = YoloV8Neck::load(vb.pp("fpn"), m)?;
let head = PoseHead::load(vb.pp("head"), num_classes, kpt, m.filters())?;
- Ok(Self { net, fpn, head })
+ Ok(Self {
+ net,
+ fpn,
+ head,
+ span: tracing::span!(tracing::Level::TRACE, "yolo-v8-pose"),
+ })
}
}
impl Module for YoloV8Pose {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
let (xs1, xs2, xs3) = self.net.forward(xs)?;
let (xs1, xs2, xs3) = self.fpn.forward(&xs1, &xs2, &xs3)?;
self.head.forward(&xs1, &xs2, &xs3)