summaryrefslogtreecommitdiff
path: root/candle-examples/examples/yolo-v8/main.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-examples/examples/yolo-v8/main.rs')
-rw-r--r--candle-examples/examples/yolo-v8/main.rs12
1 files changed, 6 insertions, 6 deletions
diff --git a/candle-examples/examples/yolo-v8/main.rs b/candle-examples/examples/yolo-v8/main.rs
index d5c5ac1c..2017b5be 100644
--- a/candle-examples/examples/yolo-v8/main.rs
+++ b/candle-examples/examples/yolo-v8/main.rs
@@ -64,7 +64,7 @@ pub fn report_detect(
let (pred_size, npreds) = pred.dims2()?;
let nclasses = pred_size - 4;
// The bounding boxes grouped by (maximum) class index.
- let mut bboxes: Vec<Vec<Bbox>> = (0..nclasses).map(|_| vec![]).collect();
+ let mut bboxes: Vec<Vec<Bbox<Vec<KeyPoint>>>> = (0..nclasses).map(|_| vec![]).collect();
// Extract the bounding boxes for which confidence is above the threshold.
for index in 0..npreds {
let pred = Vec::<f32>::try_from(pred.i((.., index))?)?;
@@ -83,7 +83,7 @@ pub fn report_detect(
xmax: pred[0] + pred[2] / 2.,
ymax: pred[1] + pred[3] / 2.,
confidence,
- keypoints: vec![],
+ data: vec![],
};
bboxes[class_index].push(bbox)
}
@@ -176,7 +176,7 @@ pub fn report_pose(
xmax: pred[0] + pred[2] / 2.,
ymax: pred[1] + pred[3] / 2.,
confidence,
- keypoints,
+ data: keypoints,
};
bboxes.push(bbox)
}
@@ -204,7 +204,7 @@ pub fn report_pose(
image::Rgb([255, 0, 0]),
);
}
- for kp in b.keypoints.iter() {
+ for kp in b.data.iter() {
if kp.mask < 0.6 {
continue;
}
@@ -219,8 +219,8 @@ pub fn report_pose(
}
for &(idx1, idx2) in KP_CONNECTIONS.iter() {
- let kp1 = &b.keypoints[idx1];
- let kp2 = &b.keypoints[idx2];
+ let kp1 = &b.data[idx1];
+ let kp2 = &b.data[idx2];
if kp1.mask < 0.6 || kp2.mask < 0.6 {
continue;
}