diff options
Diffstat (limited to 'candle-examples/examples/segment-anything/model_sam.rs')
-rw-r--r-- | candle-examples/examples/segment-anything/model_sam.rs | 103 |
1 files changed, 95 insertions, 8 deletions
diff --git a/candle-examples/examples/segment-anything/model_sam.rs b/candle-examples/examples/segment-anything/model_sam.rs index 884559af..ade976c1 100644 --- a/candle-examples/examples/segment-anything/model_sam.rs +++ b/candle-examples/examples/segment-anything/model_sam.rs @@ -8,6 +8,11 @@ use crate::model_prompt_encoder::PromptEncoder; const PROMPT_EMBED_DIM: usize = 256; pub const IMAGE_SIZE: usize = 1024; const VIT_PATCH_SIZE: usize = 16; +const PRED_IOU_THRESH: f32 = 0.88; +const STABILITY_SCORE_OFFSET: f32 = 1.0; +const STABILITY_SCORE_THRESHOLD: f32 = 0.95; +const MODEL_MASK_THRESHOLD: f32 = 0.0; +const CROP_NMS_THRESH: f32 = 0.7; #[derive(Debug)] pub struct Sam { @@ -129,7 +134,12 @@ impl Sam { img.pad_with_zeros(2, 0, IMAGE_SIZE - w) } - fn process_crop(&self, img: &Tensor, cb: CropBox, point_grids: &[(f64, f64)]) -> Result<()> { + fn process_crop( + &self, + img: &Tensor, + cb: CropBox, + point_grids: &[(f64, f64)], + ) -> Result<Vec<candle_examples::object_detection::Bbox<Tensor>>> { // Crop the image and calculate embeddings. let img = img.i((.., cb.y0..cb.y1, cb.x0..cb.x1))?; let img = self.preprocess(&img)?.unsqueeze(0)?; @@ -144,28 +154,86 @@ impl Sam { .iter() .map(|&(x, y)| vec![x as f32 * crop_w as f32, y as f32 * crop_h as f32]) .collect::<Vec<_>>(); + + let mut bboxes = Vec::new(); for points in points.chunks(64) { + // Run the model on this batch. let points_len = points.len(); let in_points = Tensor::new(points.to_vec(), img.device())?.unsqueeze(1)?; let in_labels = Tensor::ones((points_len, 1), DType::F32, img.device())?; let (sparse_prompt_embeddings, dense_prompt_embeddings) = self.prompt_encoder .forward(Some((&in_points, &in_labels)), None, None)?; - let (_low_res_mask, iou_predictions) = self.mask_decoder.forward( + + let (low_res_mask, iou_predictions) = self.mask_decoder.forward( &img_embeddings, &image_pe, &sparse_prompt_embeddings, &dense_prompt_embeddings, /* multimask_output */ true, )?; + let low_res_mask = low_res_mask.flatten(0, 1)?; + let iou_predictions = iou_predictions.flatten(0, 1)?.to_vec1::<f32>()?; + let dev = low_res_mask.device(); + + for (i, iou) in iou_predictions.iter().enumerate() { + // Filter by predicted IoU. + if *iou < PRED_IOU_THRESH { + continue; + } + let low_res_mask = low_res_mask.get(i)?; + + // Calculate stability score. + let bound = Tensor::new(MODEL_MASK_THRESHOLD + STABILITY_SCORE_OFFSET, dev)? + .broadcast_as(low_res_mask.shape())?; + let intersections = low_res_mask + .ge(&bound)? + .to_dtype(DType::F32)? + .sum_all()? + .to_vec0::<f32>()?; + let bound = Tensor::new(MODEL_MASK_THRESHOLD - STABILITY_SCORE_OFFSET, dev)? + .broadcast_as(low_res_mask.shape())?; + let unions = low_res_mask + .ge(&bound)? + .to_dtype(DType::F32)? + .sum_all()? + .to_vec0::<f32>()?; + let stability_score = intersections / unions; + if stability_score < STABILITY_SCORE_THRESHOLD { + continue; + } - println!("{cb:?} {iou_predictions}"); + // Threshold masks and calculate boxes. + let low_res_mask = low_res_mask + .ge(&Tensor::new(0f32, dev)?.broadcast_as(low_res_mask.shape())?)? + .to_dtype(DType::U32)?; + let low_res_mask_per_x = low_res_mask.sum(0)?.to_vec1::<u32>()?; + let low_res_mask_per_y = low_res_mask.sum(1)?.to_vec1::<u32>()?; + let min_max_x = min_max_indexes(&low_res_mask_per_x); + let min_max_y = min_max_indexes(&low_res_mask_per_y); + if let Some(((x0, x1), (y0, y1))) = min_max_x.zip(min_max_y) { + let bbox = candle_examples::object_detection::Bbox { + xmin: x0 as f32, + ymin: y0 as f32, + xmax: x1 as f32, + ymax: y1 as f32, + confidence: *iou, + data: low_res_mask, + }; + bboxes.push(bbox); + } + // TODO: + // Filter boxes that touch crop boundaries + // Compress to RLE. + } } + let mut bboxes = vec![bboxes]; // Remove duplicates within this crop. + candle_examples::object_detection::non_maximum_suppression(&mut bboxes, CROP_NMS_THRESH); - // Return to the original image frame. - Ok(()) + // TODO: Return to the original image frame. + Ok(bboxes.remove(0)) } pub fn generate_masks( @@ -175,7 +243,7 @@ impl Sam { crop_n_layer: usize, crop_overlap_ratio: f64, crop_n_points_downscale_factor: usize, - ) -> Result<()> { + ) -> Result<Vec<candle_examples::object_detection::Bbox<Tensor>>> { let (_c, h, w) = img.dims3()?; let point_grids = build_all_layer_point_grids( points_per_side, @@ -183,12 +251,31 @@ impl Sam { crop_n_points_downscale_factor, ); let crop_boxes = generate_crop_boxes((h, w), crop_n_layer, crop_overlap_ratio); + let mut bboxes = Vec::new(); for crop_box in crop_boxes.into_iter() { let layer_idx = crop_box.layer_idx; - self.process_crop(img, crop_box, &point_grids[layer_idx])? + let b = self.process_crop(img, crop_box, &point_grids[layer_idx])?; + bboxes.extend(b) } // TODO: remove duplicates - Ok(()) + Ok(bboxes) + } +} + +// Return the first and last indexes i for which values[i] > 0 +fn min_max_indexes(values: &[u32]) -> Option<(usize, usize)> { + let (mut min_i, mut max_i) = (usize::MAX, usize::MIN); + for (i, &s) in values.iter().enumerate() { + if s == 0 { + continue; + } + min_i = usize::min(i, min_i); + max_i = usize::max(i, max_i); + } + if max_i < min_i { + None + } else { + Some((min_i, max_i)) } } |