diff options
Diffstat (limited to 'candle-examples/examples/segment-anything')
4 files changed, 113 insertions, 14 deletions
diff --git a/candle-examples/examples/segment-anything/main.rs b/candle-examples/examples/segment-anything/main.rs index a749ba2a..4627248c 100644 --- a/candle-examples/examples/segment-anything/main.rs +++ b/candle-examples/examples/segment-anything/main.rs @@ -188,13 +188,25 @@ pub fn main() -> anyhow::Result<()> { if args.generate_masks { // Default options similar to the Python version. - sam.generate_masks( + let bboxes = sam.generate_masks( &image, /* points_per_side */ 32, /* crop_n_layer */ 0, /* crop_overlap_ratio */ 512. / 1500., /* crop_n_points_downscale_factor */ 1, - )? + )?; + for (idx, bbox) in bboxes.iter().enumerate() { + println!("{bbox:?}"); + let mask = (&bbox.data.to_dtype(DType::U8)? * 255.)?; + let (h, w) = mask.dims2()?; + let mask = mask.broadcast_as((3, h, w))?; + candle_examples::save_image_resize( + &mask, + format!("sam_mask{idx}.png"), + initial_h, + initial_w, + )?; + } } else { let point = Some((args.point_x, args.point_y)); let (mask, iou_predictions) = sam.forward(&image, point, false)?; diff --git a/candle-examples/examples/segment-anything/model_mask_decoder.rs b/candle-examples/examples/segment-anything/model_mask_decoder.rs index 1f6d62a4..c02b44a7 100644 --- a/candle-examples/examples/segment-anything/model_mask_decoder.rs +++ b/candle-examples/examples/segment-anything/model_mask_decoder.rs @@ -219,7 +219,7 @@ impl MaskDecoder { let h = mlp.forward(&mask_tokens_out.i((.., i))?)?; hyper_in_list.push(h) } - let hyper_in = Tensor::stack(hyper_in_list.as_slice(), 1)?; + let hyper_in = Tensor::stack(hyper_in_list.as_slice(), 1)?.contiguous()?; let (b, c, h, w) = upscaled_embedding.dims4()?; let masks = hyper_in.matmul(&upscaled_embedding.reshape((b, c, h * w))?)?; let masks = masks.reshape((b, (), h, w))?; diff --git a/candle-examples/examples/segment-anything/model_sam.rs b/candle-examples/examples/segment-anything/model_sam.rs index 884559af..ade976c1 100644 --- a/candle-examples/examples/segment-anything/model_sam.rs +++ b/candle-examples/examples/segment-anything/model_sam.rs @@ -8,6 +8,11 @@ use crate::model_prompt_encoder::PromptEncoder; const PROMPT_EMBED_DIM: usize = 256; pub const IMAGE_SIZE: usize = 1024; const VIT_PATCH_SIZE: usize = 16; +const PRED_IOU_THRESH: f32 = 0.88; +const STABILITY_SCORE_OFFSET: f32 = 1.0; +const STABILITY_SCORE_THRESHOLD: f32 = 0.95; +const MODEL_MASK_THRESHOLD: f32 = 0.0; +const CROP_NMS_THRESH: f32 = 0.7; #[derive(Debug)] pub struct Sam { @@ -129,7 +134,12 @@ impl Sam { img.pad_with_zeros(2, 0, IMAGE_SIZE - w) } - fn process_crop(&self, img: &Tensor, cb: CropBox, point_grids: &[(f64, f64)]) -> Result<()> { + fn process_crop( + &self, + img: &Tensor, + cb: CropBox, + point_grids: &[(f64, f64)], + ) -> Result<Vec<candle_examples::object_detection::Bbox<Tensor>>> { // Crop the image and calculate embeddings. let img = img.i((.., cb.y0..cb.y1, cb.x0..cb.x1))?; let img = self.preprocess(&img)?.unsqueeze(0)?; @@ -144,28 +154,86 @@ impl Sam { .iter() .map(|&(x, y)| vec![x as f32 * crop_w as f32, y as f32 * crop_h as f32]) .collect::<Vec<_>>(); + + let mut bboxes = Vec::new(); for points in points.chunks(64) { + // Run the model on this batch. let points_len = points.len(); let in_points = Tensor::new(points.to_vec(), img.device())?.unsqueeze(1)?; let in_labels = Tensor::ones((points_len, 1), DType::F32, img.device())?; let (sparse_prompt_embeddings, dense_prompt_embeddings) = self.prompt_encoder .forward(Some((&in_points, &in_labels)), None, None)?; - let (_low_res_mask, iou_predictions) = self.mask_decoder.forward( + + let (low_res_mask, iou_predictions) = self.mask_decoder.forward( &img_embeddings, &image_pe, &sparse_prompt_embeddings, &dense_prompt_embeddings, /* multimask_output */ true, )?; + let low_res_mask = low_res_mask.flatten(0, 1)?; + let iou_predictions = iou_predictions.flatten(0, 1)?.to_vec1::<f32>()?; + let dev = low_res_mask.device(); + + for (i, iou) in iou_predictions.iter().enumerate() { + // Filter by predicted IoU. + if *iou < PRED_IOU_THRESH { + continue; + } + let low_res_mask = low_res_mask.get(i)?; + + // Calculate stability score. + let bound = Tensor::new(MODEL_MASK_THRESHOLD + STABILITY_SCORE_OFFSET, dev)? + .broadcast_as(low_res_mask.shape())?; + let intersections = low_res_mask + .ge(&bound)? + .to_dtype(DType::F32)? + .sum_all()? + .to_vec0::<f32>()?; + let bound = Tensor::new(MODEL_MASK_THRESHOLD - STABILITY_SCORE_OFFSET, dev)? + .broadcast_as(low_res_mask.shape())?; + let unions = low_res_mask + .ge(&bound)? + .to_dtype(DType::F32)? + .sum_all()? + .to_vec0::<f32>()?; + let stability_score = intersections / unions; + if stability_score < STABILITY_SCORE_THRESHOLD { + continue; + } - println!("{cb:?} {iou_predictions}"); + // Threshold masks and calculate boxes. + let low_res_mask = low_res_mask + .ge(&Tensor::new(0f32, dev)?.broadcast_as(low_res_mask.shape())?)? + .to_dtype(DType::U32)?; + let low_res_mask_per_x = low_res_mask.sum(0)?.to_vec1::<u32>()?; + let low_res_mask_per_y = low_res_mask.sum(1)?.to_vec1::<u32>()?; + let min_max_x = min_max_indexes(&low_res_mask_per_x); + let min_max_y = min_max_indexes(&low_res_mask_per_y); + if let Some(((x0, x1), (y0, y1))) = min_max_x.zip(min_max_y) { + let bbox = candle_examples::object_detection::Bbox { + xmin: x0 as f32, + ymin: y0 as f32, + xmax: x1 as f32, + ymax: y1 as f32, + confidence: *iou, + data: low_res_mask, + }; + bboxes.push(bbox); + } + // TODO: + // Filter boxes that touch crop boundaries + // Compress to RLE. + } } + let mut bboxes = vec![bboxes]; // Remove duplicates within this crop. + candle_examples::object_detection::non_maximum_suppression(&mut bboxes, CROP_NMS_THRESH); - // Return to the original image frame. - Ok(()) + // TODO: Return to the original image frame. + Ok(bboxes.remove(0)) } pub fn generate_masks( @@ -175,7 +243,7 @@ impl Sam { crop_n_layer: usize, crop_overlap_ratio: f64, crop_n_points_downscale_factor: usize, - ) -> Result<()> { + ) -> Result<Vec<candle_examples::object_detection::Bbox<Tensor>>> { let (_c, h, w) = img.dims3()?; let point_grids = build_all_layer_point_grids( points_per_side, @@ -183,12 +251,31 @@ impl Sam { crop_n_points_downscale_factor, ); let crop_boxes = generate_crop_boxes((h, w), crop_n_layer, crop_overlap_ratio); + let mut bboxes = Vec::new(); for crop_box in crop_boxes.into_iter() { let layer_idx = crop_box.layer_idx; - self.process_crop(img, crop_box, &point_grids[layer_idx])? + let b = self.process_crop(img, crop_box, &point_grids[layer_idx])?; + bboxes.extend(b) } // TODO: remove duplicates - Ok(()) + Ok(bboxes) + } +} + +// Return the first and last indexes i for which values[i] > 0 +fn min_max_indexes(values: &[u32]) -> Option<(usize, usize)> { + let (mut min_i, mut max_i) = (usize::MAX, usize::MIN); + for (i, &s) in values.iter().enumerate() { + if s == 0 { + continue; + } + min_i = usize::min(i, min_i); + max_i = usize::max(i, max_i); + } + if max_i < min_i { + None + } else { + Some((min_i, max_i)) } } diff --git a/candle-examples/examples/segment-anything/model_transformer.rs b/candle-examples/examples/segment-anything/model_transformer.rs index e4de27cb..e12aac08 100644 --- a/candle-examples/examples/segment-anything/model_transformer.rs +++ b/candle-examples/examples/segment-anything/model_transformer.rs @@ -45,9 +45,9 @@ impl Attention { } fn forward(&self, q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> { - let q = self.q_proj.forward(q)?; - let k = self.k_proj.forward(k)?; - let v = self.v_proj.forward(v)?; + let q = self.q_proj.forward(&q.contiguous()?)?; + let k = self.k_proj.forward(&k.contiguous()?)?; + let v = self.v_proj.forward(&v.contiguous()?)?; let q = self.separate_heads(&q)?; let k = self.separate_heads(&k)?; |