diff options
Diffstat (limited to 'candle-wasm-examples/yolo/src/worker.rs')
-rw-r--r-- | candle-wasm-examples/yolo/src/worker.rs | 84 |
1 file changed, 82 insertions, 2 deletions
/// Pose-estimation counterpart of `Model`: wraps a `YoloV8Pose` network and
/// runs it on user-supplied image bytes inside the WASM worker.
pub struct ModelPose {
    // The loaded YOLOv8-pose network (CPU, F32 weights).
    model: YoloV8Pose,
}

impl ModelPose {
    /// Decodes `image_data` (any format `image` can sniff), resizes it so the
    /// longer side becomes 640 px with both sides rounded down to a multiple
    /// of 32, runs the pose model, and returns the resulting bounding boxes.
    ///
    /// * `conf_threshold` — minimum confidence forwarded to `report_pose`.
    /// * `iou_threshold` — IoU cutoff forwarded to `report_pose` for
    ///   non-maximum suppression.
    ///
    /// Errors from decoding are wrapped into `candle::Error`; tensor and
    /// model errors propagate via `?`.
    pub fn run(
        &self,
        image_data: Vec<u8>,
        conf_threshold: f32,
        iou_threshold: f32,
    ) -> Result<Vec<Bbox>> {
        console_log!("image data: {}", image_data.len());
        let image_data = std::io::Cursor::new(image_data);
        // Guess the image format from the byte stream rather than a filename.
        let original_image = image::io::Reader::new(image_data)
            .with_guessed_format()?
            .decode()
            .map_err(candle::Error::wrap)?;
        // Compute the network input size: scale the longer side to 640 and
        // round the shorter side down to a multiple of 32 (the model's stride).
        let (width, height) = {
            let w = original_image.width() as usize;
            let h = original_image.height() as usize;
            if w < h {
                let w = w * 640 / h;
                // Sizes have to be divisible by 32.
                (w / 32 * 32, 640)
            } else {
                let h = h * 640 / w;
                (640, h / 32 * 32)
            }
        };
        let image_t = {
            let img = original_image.resize_exact(
                width as u32,
                height as u32,
                image::imageops::FilterType::CatmullRom,
            );
            let data = img.to_rgb8().into_raw();
            // Build an HWC u8 tensor from the raw RGB bytes, then permute to
            // the CHW layout the model expects.
            Tensor::from_vec(
                data,
                (img.height() as usize, img.width() as usize, 3),
                &Device::Cpu,
            )?
            .permute((2, 0, 1))?
        };
        // Add the batch dimension and scale pixel values from [0, 255] to [0, 1].
        let image_t = (image_t.unsqueeze(0)?.to_dtype(DType::F32)? * (1. / 255.))?;
        let predictions = self.model.forward(&image_t)?.squeeze(0)?;
        console_log!("generated predictions {predictions:?}");
        // Post-process raw predictions into boxes (confidence filter + NMS),
        // mapping coordinates back onto the original image dimensions.
        let bboxes = report_pose(
            &predictions,
            original_image,
            width,
            height,
            conf_threshold,
            iou_threshold,
        )?;
        Ok(bboxes)
    }

    /// Builds a pose model from raw safetensors bytes.
    ///
    /// `model_size` selects the YOLOv8 width/depth multiples and must be one
    /// of "n", "s", "m", "l" or "x"; anything else returns an error (the
    /// `Err(..)?` in the match arm early-returns).
    pub fn load_(weights: &[u8], model_size: &str) -> Result<Self> {
        let multiples = match model_size {
            "n" => Multiples::n(),
            "s" => Multiples::s(),
            "m" => Multiples::m(),
            "l" => Multiples::l(),
            "x" => Multiples::x(),
            _ => Err(candle::Error::Msg(
                "invalid model size: must be n, s, m, l or x".to_string(),
            ))?,
        };
        let dev = &Device::Cpu;
        let weights = safetensors::tensor::SafeTensors::deserialize(weights)?;
        let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, dev);
        // 1 class, 17 keypoints with 3 values each — presumably the standard
        // COCO (x, y, confidence) keypoint layout; confirm against YoloV8Pose.
        let model = YoloV8Pose::load(vb, multiples, 1, (17, 3))?;
        Ok(Self { model })
    }

    /// Convenience wrapper: loads the model from a `ModelData` message.
    pub fn load(md: ModelData) -> Result<Self> {
        // NOTE(review): `.to_string()` allocates here; if `ModelData::model_size`
        // is already a `String`, `&md.model_size` would coerce to `&str` without
        // the extra allocation — verify the field's type before changing.
        Self::load_(&md.weights, &md.model_size.to_string())
    }
}