diff options
-rw-r--r-- | candle-examples/examples/segment-anything/main.rs | 16 | ||||
-rw-r--r-- | candle-examples/examples/segment-anything/model_mask_decoder.rs | 2 | ||||
-rw-r--r-- | candle-examples/examples/segment-anything/model_sam.rs | 103 | ||||
-rw-r--r-- | candle-examples/examples/segment-anything/model_transformer.rs | 6 | ||||
-rw-r--r-- | candle-examples/examples/yolo-v3/main.rs | 4 | ||||
-rw-r--r-- | candle-examples/examples/yolo-v8/main.rs | 12 | ||||
-rw-r--r-- | candle-examples/src/object_detection.rs | 8 |
7 files changed, 125 insertions, 26 deletions
diff --git a/candle-examples/examples/segment-anything/main.rs b/candle-examples/examples/segment-anything/main.rs index a749ba2a..4627248c 100644 --- a/candle-examples/examples/segment-anything/main.rs +++ b/candle-examples/examples/segment-anything/main.rs @@ -188,13 +188,25 @@ pub fn main() -> anyhow::Result<()> { if args.generate_masks { // Default options similar to the Python version. - sam.generate_masks( + let bboxes = sam.generate_masks( &image, /* points_per_side */ 32, /* crop_n_layer */ 0, /* crop_overlap_ratio */ 512. / 1500., /* crop_n_points_downscale_factor */ 1, - )? + )?; + for (idx, bbox) in bboxes.iter().enumerate() { + println!("{bbox:?}"); + let mask = (&bbox.data.to_dtype(DType::U8)? * 255.)?; + let (h, w) = mask.dims2()?; + let mask = mask.broadcast_as((3, h, w))?; + candle_examples::save_image_resize( + &mask, + format!("sam_mask{idx}.png"), + initial_h, + initial_w, + )?; + } } else { let point = Some((args.point_x, args.point_y)); let (mask, iou_predictions) = sam.forward(&image, point, false)?; diff --git a/candle-examples/examples/segment-anything/model_mask_decoder.rs b/candle-examples/examples/segment-anything/model_mask_decoder.rs index 1f6d62a4..c02b44a7 100644 --- a/candle-examples/examples/segment-anything/model_mask_decoder.rs +++ b/candle-examples/examples/segment-anything/model_mask_decoder.rs @@ -219,7 +219,7 @@ impl MaskDecoder { let h = mlp.forward(&mask_tokens_out.i((.., i))?)?; hyper_in_list.push(h) } - let hyper_in = Tensor::stack(hyper_in_list.as_slice(), 1)?; + let hyper_in = Tensor::stack(hyper_in_list.as_slice(), 1)?.contiguous()?; let (b, c, h, w) = upscaled_embedding.dims4()?; let masks = hyper_in.matmul(&upscaled_embedding.reshape((b, c, h * w))?)?; let masks = masks.reshape((b, (), h, w))?; diff --git a/candle-examples/examples/segment-anything/model_sam.rs b/candle-examples/examples/segment-anything/model_sam.rs index 884559af..ade976c1 100644 --- a/candle-examples/examples/segment-anything/model_sam.rs +++ b/candle-examples/examples/segment-anything/model_sam.rs @@ -8,6 +8,11 @@ use crate::model_prompt_encoder::PromptEncoder; const PROMPT_EMBED_DIM: usize = 256; pub const IMAGE_SIZE: usize = 1024; const VIT_PATCH_SIZE: usize = 16; +const PRED_IOU_THRESH: f32 = 0.88; +const STABILITY_SCORE_OFFSET: f32 = 1.0; +const STABILITY_SCORE_THRESHOLD: f32 = 0.95; +const MODEL_MASK_THRESHOLD: f32 = 0.0; +const CROP_NMS_THRESH: f32 = 0.7; #[derive(Debug)] pub struct Sam { @@ -129,7 +134,12 @@ impl Sam { img.pad_with_zeros(2, 0, IMAGE_SIZE - w) } - fn process_crop(&self, img: &Tensor, cb: CropBox, point_grids: &[(f64, f64)]) -> Result<()> { + fn process_crop( + &self, + img: &Tensor, + cb: CropBox, + point_grids: &[(f64, f64)], + ) -> Result<Vec<candle_examples::object_detection::Bbox<Tensor>>> { // Crop the image and calculate embeddings. let img = img.i((.., cb.y0..cb.y1, cb.x0..cb.x1))?; let img = self.preprocess(&img)?.unsqueeze(0)?; @@ -144,28 +154,86 @@ impl Sam { .iter() .map(|&(x, y)| vec![x as f32 * crop_w as f32, y as f32 * crop_h as f32]) .collect::<Vec<_>>(); + + let mut bboxes = Vec::new(); for points in points.chunks(64) { + // Run the model on this batch. let points_len = points.len(); let in_points = Tensor::new(points.to_vec(), img.device())?.unsqueeze(1)?; let in_labels = Tensor::ones((points_len, 1), DType::F32, img.device())?; let (sparse_prompt_embeddings, dense_prompt_embeddings) = self.prompt_encoder .forward(Some((&in_points, &in_labels)), None, None)?; - let (_low_res_mask, iou_predictions) = self.mask_decoder.forward( + + let (low_res_mask, iou_predictions) = self.mask_decoder.forward( &img_embeddings, &image_pe, &sparse_prompt_embeddings, &dense_prompt_embeddings, /* multimask_output */ true, )?; + let low_res_mask = low_res_mask.flatten(0, 1)?; + let iou_predictions = iou_predictions.flatten(0, 1)?.to_vec1::<f32>()?; + let dev = low_res_mask.device(); + + for (i, iou) in iou_predictions.iter().enumerate() { + // Filter by predicted IoU. + if *iou < PRED_IOU_THRESH { + continue; + } + let low_res_mask = low_res_mask.get(i)?; + + // Calculate stability score. + let bound = Tensor::new(MODEL_MASK_THRESHOLD + STABILITY_SCORE_OFFSET, dev)? + .broadcast_as(low_res_mask.shape())?; + let intersections = low_res_mask + .ge(&bound)? + .to_dtype(DType::F32)? + .sum_all()? + .to_vec0::<f32>()?; + let bound = Tensor::new(MODEL_MASK_THRESHOLD - STABILITY_SCORE_OFFSET, dev)? + .broadcast_as(low_res_mask.shape())?; + let unions = low_res_mask + .ge(&bound)? + .to_dtype(DType::F32)? + .sum_all()? + .to_vec0::<f32>()?; + let stability_score = intersections / unions; + if stability_score < STABILITY_SCORE_THRESHOLD { + continue; + } - println!("{cb:?} {iou_predictions}"); + // Threshold masks and calculate boxes. + let low_res_mask = low_res_mask + .ge(&Tensor::new(0f32, dev)?.broadcast_as(low_res_mask.shape())?)? + .to_dtype(DType::U32)?; + let low_res_mask_per_x = low_res_mask.sum(0)?.to_vec1::<u32>()?; + let low_res_mask_per_y = low_res_mask.sum(1)?.to_vec1::<u32>()?; + let min_max_x = min_max_indexes(&low_res_mask_per_x); + let min_max_y = min_max_indexes(&low_res_mask_per_y); + if let Some(((x0, x1), (y0, y1))) = min_max_x.zip(min_max_y) { + let bbox = candle_examples::object_detection::Bbox { + xmin: x0 as f32, + ymin: y0 as f32, + xmax: x1 as f32, + ymax: y1 as f32, + confidence: *iou, + data: low_res_mask, + }; + bboxes.push(bbox); + } + // TODO: + // Filter boxes that touch crop boundaries + // Compress to RLE. + } } + let mut bboxes = vec![bboxes]; // Remove duplicates within this crop. + candle_examples::object_detection::non_maximum_suppression(&mut bboxes, CROP_NMS_THRESH); - // Return to the original image frame. - Ok(()) + // TODO: Return to the original image frame. + Ok(bboxes.remove(0)) } pub fn generate_masks( @@ -175,7 +243,7 @@ impl Sam { crop_n_layer: usize, crop_overlap_ratio: f64, crop_n_points_downscale_factor: usize, - ) -> Result<()> { + ) -> Result<Vec<candle_examples::object_detection::Bbox<Tensor>>> { let (_c, h, w) = img.dims3()?; let point_grids = build_all_layer_point_grids( points_per_side, @@ -183,12 +251,31 @@ impl Sam { crop_n_points_downscale_factor, ); let crop_boxes = generate_crop_boxes((h, w), crop_n_layer, crop_overlap_ratio); + let mut bboxes = Vec::new(); for crop_box in crop_boxes.into_iter() { let layer_idx = crop_box.layer_idx; - self.process_crop(img, crop_box, &point_grids[layer_idx])? + let b = self.process_crop(img, crop_box, &point_grids[layer_idx])?; + bboxes.extend(b) } // TODO: remove duplicates - Ok(()) + Ok(bboxes) + } +} + +// Return the first and last indexes i for which values[i] > 0 +fn min_max_indexes(values: &[u32]) -> Option<(usize, usize)> { + let (mut min_i, mut max_i) = (usize::MAX, usize::MIN); + for (i, &s) in values.iter().enumerate() { + if s == 0 { + continue; + } + min_i = usize::min(i, min_i); + max_i = usize::max(i, max_i); + } + if max_i < min_i { + None + } else { + Some((min_i, max_i)) } } diff --git a/candle-examples/examples/segment-anything/model_transformer.rs b/candle-examples/examples/segment-anything/model_transformer.rs index e4de27cb..e12aac08 100644 --- a/candle-examples/examples/segment-anything/model_transformer.rs +++ b/candle-examples/examples/segment-anything/model_transformer.rs @@ -45,9 +45,9 @@ impl Attention { } fn forward(&self, q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> { - let q = self.q_proj.forward(q)?; - let k = self.k_proj.forward(k)?; - let v = self.v_proj.forward(v)?; + let q = self.q_proj.forward(&q.contiguous()?)?; + let k = self.k_proj.forward(&k.contiguous()?)?; + let v = self.v_proj.forward(&v.contiguous()?)?; let q = self.separate_heads(&q)?; let k = self.separate_heads(&k)?; diff --git a/candle-examples/examples/yolo-v3/main.rs b/candle-examples/examples/yolo-v3/main.rs index 5e388921..20021b45 100644 --- a/candle-examples/examples/yolo-v3/main.rs +++ b/candle-examples/examples/yolo-v3/main.rs @@ -46,7 +46,7 @@ pub fn report( let (npreds, pred_size) = pred.dims2()?; let nclasses = pred_size - 5; // The bounding boxes grouped by (maximum) class index. - let mut bboxes: Vec<Vec<Bbox>> = (0..nclasses).map(|_| vec![]).collect(); + let mut bboxes: Vec<Vec<Bbox<()>>> = (0..nclasses).map(|_| vec![]).collect(); // Extract the bounding boxes for which confidence is above the threshold. for index in 0..npreds { let pred = Vec::<f32>::try_from(pred.get(index)?)?; @@ -65,7 +65,7 @@ pub fn report( xmax: pred[0] + pred[2] / 2., ymax: pred[1] + pred[3] / 2., confidence, - keypoints: vec![], + data: (), }; bboxes[class_index].push(bbox) } diff --git a/candle-examples/examples/yolo-v8/main.rs b/candle-examples/examples/yolo-v8/main.rs index d5c5ac1c..2017b5be 100644 --- a/candle-examples/examples/yolo-v8/main.rs +++ b/candle-examples/examples/yolo-v8/main.rs @@ -64,7 +64,7 @@ pub fn report_detect( let (pred_size, npreds) = pred.dims2()?; let nclasses = pred_size - 4; // The bounding boxes grouped by (maximum) class index. - let mut bboxes: Vec<Vec<Bbox>> = (0..nclasses).map(|_| vec![]).collect(); + let mut bboxes: Vec<Vec<Bbox<Vec<KeyPoint>>>> = (0..nclasses).map(|_| vec![]).collect(); // Extract the bounding boxes for which confidence is above the threshold. for index in 0..npreds { let pred = Vec::<f32>::try_from(pred.i((.., index))?)?; @@ -83,7 +83,7 @@ pub fn report_detect( xmax: pred[0] + pred[2] / 2., ymax: pred[1] + pred[3] / 2., confidence, - keypoints: vec![], + data: vec![], }; bboxes[class_index].push(bbox) } @@ -176,7 +176,7 @@ pub fn report_pose( xmax: pred[0] + pred[2] / 2., ymax: pred[1] + pred[3] / 2., confidence, - keypoints, + data: keypoints, }; bboxes.push(bbox) } @@ -204,7 +204,7 @@ pub fn report_pose( image::Rgb([255, 0, 0]), ); } - for kp in b.keypoints.iter() { + for kp in b.data.iter() { if kp.mask < 0.6 { continue; } @@ -219,8 +219,8 @@ pub fn report_pose( } for &(idx1, idx2) in KP_CONNECTIONS.iter() { - let kp1 = &b.keypoints[idx1]; - let kp2 = &b.keypoints[idx2]; + let kp1 = &b.data[idx1]; + let kp2 = &b.data[idx2]; if kp1.mask < 0.6 || kp2.mask < 0.6 { continue; } diff --git a/candle-examples/src/object_detection.rs b/candle-examples/src/object_detection.rs index c7c60136..ce579316 100644 --- a/candle-examples/src/object_detection.rs +++ b/candle-examples/src/object_detection.rs @@ -1,12 +1,12 @@ /// A bounding box around an object. #[derive(Debug, Clone)] -pub struct Bbox { +pub struct Bbox<D> { pub xmin: f32, pub ymin: f32, pub xmax: f32, pub ymax: f32, pub confidence: f32, - pub keypoints: Vec<KeyPoint>, + pub data: D, } #[derive(Debug, Clone, Copy, PartialEq)] @@ -17,7 +17,7 @@ pub struct KeyPoint { } /// Intersection over union of two bounding boxes. -pub fn iou(b1: &Bbox, b2: &Bbox) -> f32 { +pub fn iou<D>(b1: &Bbox<D>, b2: &Bbox<D>) -> f32 { let b1_area = (b1.xmax - b1.xmin + 1.) * (b1.ymax - b1.ymin + 1.); let b2_area = (b2.xmax - b2.xmin + 1.) * (b2.ymax - b2.ymin + 1.); let i_xmin = b1.xmin.max(b2.xmin); @@ -28,7 +28,7 @@ pub fn iou(b1: &Bbox, b2: &Bbox) -> f32 { i_area / (b1_area + b2_area - i_area) } -pub fn non_maximum_suppression(bboxes: &mut [Vec<Bbox>], threshold: f32) { +pub fn non_maximum_suppression<D>(bboxes: &mut [Vec<Bbox<D>>], threshold: f32) { // Perform non-maximum suppression. for bboxes_for_class in bboxes.iter_mut() { bboxes_for_class.sort_by(|b1, b2| b2.confidence.partial_cmp(&b1.confidence).unwrap()); |