diff options
Diffstat (limited to 'candle-examples/examples/yolo-v8/main.rs')
-rw-r--r-- | candle-examples/examples/yolo-v8/main.rs | 4 |
1 files changed, 3 insertions, 1 deletions
diff --git a/candle-examples/examples/yolo-v8/main.rs b/candle-examples/examples/yolo-v8/main.rs index af8cf98a..c65a5ca1 100644 --- a/candle-examples/examples/yolo-v8/main.rs +++ b/candle-examples/examples/yolo-v8/main.rs @@ -7,7 +7,7 @@ extern crate accelerate_src; mod model; use model::{Multiples, YoloV8, YoloV8Pose}; -use candle::{DType, IndexOp, Result, Tensor}; +use candle::{DType, Device, IndexOp, Result, Tensor}; use candle_nn::{Module, VarBuilder}; use candle_transformers::object_detection::{non_maximum_suppression, Bbox, KeyPoint}; use clap::{Parser, ValueEnum}; @@ -61,6 +61,7 @@ pub fn report_detect( nms_threshold: f32, legend_size: u32, ) -> Result<DynamicImage> { + let pred = pred.to_device(&Device::Cpu)?; let (pred_size, npreds) = pred.dims2()?; let nclasses = pred_size - 4; // The bounding boxes grouped by (maximum) class index. @@ -153,6 +154,7 @@ pub fn report_pose( confidence_threshold: f32, nms_threshold: f32, ) -> Result<DynamicImage> { + let pred = pred.to_device(&Device::Cpu)?; let (pred_size, npreds) = pred.dims2()?; if pred_size != 17 * 3 + 4 + 1 { candle::bail!("unexpected pred-size {pred_size}"); |