summaryrefslogtreecommitdiff
path: root/candle-examples/examples/segment-anything/model_sam.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-examples/examples/segment-anything/model_sam.rs')
-rw-r--r--candle-examples/examples/segment-anything/model_sam.rs103
1 files changed, 95 insertions, 8 deletions
diff --git a/candle-examples/examples/segment-anything/model_sam.rs b/candle-examples/examples/segment-anything/model_sam.rs
index 884559af..ade976c1 100644
--- a/candle-examples/examples/segment-anything/model_sam.rs
+++ b/candle-examples/examples/segment-anything/model_sam.rs
@@ -8,6 +8,11 @@ use crate::model_prompt_encoder::PromptEncoder;
const PROMPT_EMBED_DIM: usize = 256;
pub const IMAGE_SIZE: usize = 1024;
const VIT_PATCH_SIZE: usize = 16;
+const PRED_IOU_THRESH: f32 = 0.88;
+const STABILITY_SCORE_OFFSET: f32 = 1.0;
+const STABILITY_SCORE_THRESHOLD: f32 = 0.95;
+const MODEL_MASK_THRESHOLD: f32 = 0.0;
+const CROP_NMS_THRESH: f32 = 0.7;
#[derive(Debug)]
pub struct Sam {
@@ -129,7 +134,12 @@ impl Sam {
img.pad_with_zeros(2, 0, IMAGE_SIZE - w)
}
- fn process_crop(&self, img: &Tensor, cb: CropBox, point_grids: &[(f64, f64)]) -> Result<()> {
+ fn process_crop(
+ &self,
+ img: &Tensor,
+ cb: CropBox,
+ point_grids: &[(f64, f64)],
+ ) -> Result<Vec<candle_examples::object_detection::Bbox<Tensor>>> {
// Crop the image and calculate embeddings.
let img = img.i((.., cb.y0..cb.y1, cb.x0..cb.x1))?;
let img = self.preprocess(&img)?.unsqueeze(0)?;
@@ -144,28 +154,86 @@ impl Sam {
.iter()
.map(|&(x, y)| vec![x as f32 * crop_w as f32, y as f32 * crop_h as f32])
.collect::<Vec<_>>();
+
+ let mut bboxes = Vec::new();
for points in points.chunks(64) {
+ // Run the model on this batch.
let points_len = points.len();
let in_points = Tensor::new(points.to_vec(), img.device())?.unsqueeze(1)?;
let in_labels = Tensor::ones((points_len, 1), DType::F32, img.device())?;
let (sparse_prompt_embeddings, dense_prompt_embeddings) =
self.prompt_encoder
.forward(Some((&in_points, &in_labels)), None, None)?;
- let (_low_res_mask, iou_predictions) = self.mask_decoder.forward(
+
+ let (low_res_mask, iou_predictions) = self.mask_decoder.forward(
&img_embeddings,
&image_pe,
&sparse_prompt_embeddings,
&dense_prompt_embeddings,
/* multimask_output */ true,
)?;
+ let low_res_mask = low_res_mask.flatten(0, 1)?;
+ let iou_predictions = iou_predictions.flatten(0, 1)?.to_vec1::<f32>()?;
+ let dev = low_res_mask.device();
+
+ for (i, iou) in iou_predictions.iter().enumerate() {
+ // Filter by predicted IoU.
+ if *iou < PRED_IOU_THRESH {
+ continue;
+ }
+ let low_res_mask = low_res_mask.get(i)?;
+
+ // Calculate stability score.
+ let bound = Tensor::new(MODEL_MASK_THRESHOLD + STABILITY_SCORE_OFFSET, dev)?
+ .broadcast_as(low_res_mask.shape())?;
+ let intersections = low_res_mask
+ .ge(&bound)?
+ .to_dtype(DType::F32)?
+ .sum_all()?
+ .to_vec0::<f32>()?;
+ let bound = Tensor::new(MODEL_MASK_THRESHOLD - STABILITY_SCORE_OFFSET, dev)?
+ .broadcast_as(low_res_mask.shape())?;
+ let unions = low_res_mask
+ .ge(&bound)?
+ .to_dtype(DType::F32)?
+ .sum_all()?
+ .to_vec0::<f32>()?;
+ let stability_score = intersections / unions;
+ if stability_score < STABILITY_SCORE_THRESHOLD {
+ continue;
+ }
- println!("{cb:?} {iou_predictions}");
+ // Threshold masks and calculate boxes.
+ let low_res_mask = low_res_mask
+ .ge(&Tensor::new(0f32, dev)?.broadcast_as(low_res_mask.shape())?)?
+ .to_dtype(DType::U32)?;
+ let low_res_mask_per_x = low_res_mask.sum(0)?.to_vec1::<u32>()?;
+ let low_res_mask_per_y = low_res_mask.sum(1)?.to_vec1::<u32>()?;
+ let min_max_x = min_max_indexes(&low_res_mask_per_x);
+ let min_max_y = min_max_indexes(&low_res_mask_per_y);
+ if let Some(((x0, x1), (y0, y1))) = min_max_x.zip(min_max_y) {
+ let bbox = candle_examples::object_detection::Bbox {
+ xmin: x0 as f32,
+ ymin: y0 as f32,
+ xmax: x1 as f32,
+ ymax: y1 as f32,
+ confidence: *iou,
+ data: low_res_mask,
+ };
+ bboxes.push(bbox);
+ }
+ // TODO:
+ // Filter boxes that touch crop boundaries
+ // Compress to RLE.
+ }
}
+ let mut bboxes = vec![bboxes];
// Remove duplicates within this crop.
+ candle_examples::object_detection::non_maximum_suppression(&mut bboxes, CROP_NMS_THRESH);
- // Return to the original image frame.
- Ok(())
+ // TODO: Return to the original image frame.
+ Ok(bboxes.remove(0))
}
pub fn generate_masks(
@@ -175,7 +243,7 @@ impl Sam {
crop_n_layer: usize,
crop_overlap_ratio: f64,
crop_n_points_downscale_factor: usize,
- ) -> Result<()> {
+ ) -> Result<Vec<candle_examples::object_detection::Bbox<Tensor>>> {
let (_c, h, w) = img.dims3()?;
let point_grids = build_all_layer_point_grids(
points_per_side,
@@ -183,12 +251,31 @@ impl Sam {
crop_n_points_downscale_factor,
);
let crop_boxes = generate_crop_boxes((h, w), crop_n_layer, crop_overlap_ratio);
+ let mut bboxes = Vec::new();
for crop_box in crop_boxes.into_iter() {
let layer_idx = crop_box.layer_idx;
- self.process_crop(img, crop_box, &point_grids[layer_idx])?
+ let b = self.process_crop(img, crop_box, &point_grids[layer_idx])?;
+ bboxes.extend(b)
}
// TODO: remove duplicates
- Ok(())
+ Ok(bboxes)
+ }
+}
+
+// Returns `Some((first, last))`, the smallest and largest indexes i for which values[i] > 0, or `None` when `values` is empty or every entry is zero.
+fn min_max_indexes(values: &[u32]) -> Option<(usize, usize)> {
+ let (mut min_i, mut max_i) = (usize::MAX, usize::MIN); // Sentinels: remain (MAX, 0) until a non-zero entry is seen.
+ for (i, &s) in values.iter().enumerate() {
+ if s == 0 {
+ continue;
+ }
+ min_i = usize::min(i, min_i);
+ max_i = usize::max(i, max_i);
+ }
+ if max_i < min_i { // Holds only when the sentinels were never updated, i.e. no non-zero entry was found.
+ None
+ } else {
+ Some((min_i, max_i))
}
}