diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-09-08 12:26:56 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-09-08 12:26:56 +0100 |
commit | 28c87f6a34e594aca5f558bceebc4c0a9c95911a (patch) | |
tree | 11d702a507de898a7e734aa22349657d04931fb4 /candle-examples/examples/segment-anything/model_sam.rs | |
parent | c1453f00b11c9dd12c5aa81fb4355ce47d22d477 (diff) | |
download | candle-28c87f6a34e594aca5f558bceebc4c0a9c95911a.tar.gz candle-28c87f6a34e594aca5f558bceebc4c0a9c95911a.tar.bz2 candle-28c87f6a34e594aca5f558bceebc4c0a9c95911a.zip |
Automatic mask generator + point base mask (#773)
* Add more to the automatic mask generator.
* Add the target point.
* Fix.
* Remove the allow-unused.
* Mask post-processing.
Diffstat (limited to 'candle-examples/examples/segment-anything/model_sam.rs')
-rw-r--r-- | candle-examples/examples/segment-anything/model_sam.rs | 181 |
1 files changed, 174 insertions, 7 deletions
diff --git a/candle-examples/examples/segment-anything/model_sam.rs b/candle-examples/examples/segment-anything/model_sam.rs index 237163a3..884559af 100644 --- a/candle-examples/examples/segment-anything/model_sam.rs +++ b/candle-examples/examples/segment-anything/model_sam.rs @@ -1,5 +1,5 @@ -use candle::{DType, IndexOp, Result, Tensor, D}; -use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder}; +use candle::{DType, IndexOp, Result, Tensor}; +use candle_nn::{Module, VarBuilder}; use crate::model_image_encoder::ImageEncoderViT; use crate::model_mask_decoder::MaskDecoder; @@ -70,12 +70,30 @@ impl Sam { }) } - pub fn forward(&self, img: &Tensor, multimask_output: bool) -> Result<(Tensor, Tensor)> { + pub fn forward( + &self, + img: &Tensor, + point: Option<(f64, f64)>, + multimask_output: bool, + ) -> Result<(Tensor, Tensor)> { + let (_c, original_h, original_w) = img.dims3()?; let img = self.preprocess(img)?.unsqueeze(0)?; let img_embeddings = self.image_encoder.forward(&img)?; let image_pe = self.prompt_encoder.get_dense_pe()?; + let points = match point { + None => None, + Some((x, y)) => { + let points = Tensor::new( + &[[[x as f32 * original_w as f32, y as f32 * original_h as f32]]], + img.device(), + )?; + let labels = Tensor::ones((1, 1), DType::F32, img.device())?; + Some((points, labels)) + } + }; + let points = points.as_ref().map(|(x, y)| (x, y)); let (sparse_prompt_embeddings, dense_prompt_embeddings) = - self.prompt_encoder.forward(None, None, None)?; + self.prompt_encoder.forward(points, None, None)?; let (low_res_mask, iou_predictions) = self.mask_decoder.forward( &img_embeddings, &image_pe, @@ -83,8 +101,11 @@ impl Sam { &dense_prompt_embeddings, multimask_output, )?; - // TODO: post-processing. - Ok((low_res_mask, iou_predictions)) + let mask = low_res_mask + .upsample_nearest2d(IMAGE_SIZE, IMAGE_SIZE)? + .get(0)? + .i((.., ..original_h, ..original_w))?; + Ok((mask, iou_predictions)) } pub fn unpreprocess(&self, img: &Tensor) -> Result<Tensor> { @@ -96,7 +117,7 @@ impl Sam { } pub fn preprocess(&self, img: &Tensor) -> Result<Tensor> { - let (c, h, w) = img.dims3()?; + let (_c, h, w) = img.dims3()?; let img = img .to_dtype(DType::F32)? .broadcast_sub(&self.pixel_mean)? @@ -107,4 +128,150 @@ impl Sam { let img = img.pad_with_zeros(1, 0, IMAGE_SIZE - h)?; img.pad_with_zeros(2, 0, IMAGE_SIZE - w) } + + fn process_crop(&self, img: &Tensor, cb: CropBox, point_grids: &[(f64, f64)]) -> Result<()> { + // Crop the image and calculate embeddings. + let img = img.i((.., cb.y0..cb.y1, cb.x0..cb.x1))?; + let img = self.preprocess(&img)?.unsqueeze(0)?; + let img_embeddings = self.image_encoder.forward(&img)?; + + let crop_w = cb.x1 - cb.x0; + let crop_h = cb.y1 - cb.y0; + + // Generate masks for this crop. + let image_pe = self.prompt_encoder.get_dense_pe()?; + let points = point_grids + .iter() + .map(|&(x, y)| vec![x as f32 * crop_w as f32, y as f32 * crop_h as f32]) + .collect::<Vec<_>>(); + for points in points.chunks(64) { + let points_len = points.len(); + let in_points = Tensor::new(points.to_vec(), img.device())?.unsqueeze(1)?; + let in_labels = Tensor::ones((points_len, 1), DType::F32, img.device())?; + let (sparse_prompt_embeddings, dense_prompt_embeddings) = + self.prompt_encoder + .forward(Some((&in_points, &in_labels)), None, None)?; + let (_low_res_mask, iou_predictions) = self.mask_decoder.forward( + &img_embeddings, + &image_pe, + &sparse_prompt_embeddings, + &dense_prompt_embeddings, + /* multimask_output */ true, + )?; + + println!("{cb:?} {iou_predictions}"); + } + + // Remove duplicates within this crop. + + // Return to the original image frame. + Ok(()) + } + + pub fn generate_masks( + &self, + img: &Tensor, + points_per_side: usize, + crop_n_layer: usize, + crop_overlap_ratio: f64, + crop_n_points_downscale_factor: usize, + ) -> Result<()> { + let (_c, h, w) = img.dims3()?; + let point_grids = build_all_layer_point_grids( + points_per_side, + crop_n_layer, + crop_n_points_downscale_factor, + ); + let crop_boxes = generate_crop_boxes((h, w), crop_n_layer, crop_overlap_ratio); + for crop_box in crop_boxes.into_iter() { + let layer_idx = crop_box.layer_idx; + self.process_crop(img, crop_box, &point_grids[layer_idx])? + } + // TODO: remove duplicates + Ok(()) + } +} + +#[derive(Debug)] +struct CropBox { + x0: usize, + y0: usize, + x1: usize, + y1: usize, + layer_idx: usize, +} + +impl CropBox { + fn new(x0: usize, y0: usize, x1: usize, y1: usize, layer_idx: usize) -> Self { + Self { + x0, + y0, + x1, + y1, + layer_idx, + } + } +} + +fn generate_crop_boxes( + (im_h, im_w): (usize, usize), + n_layers: usize, + overlap_ratio: f64, +) -> Vec<CropBox> { + fn crop_len(orig_len: usize, n_crops: usize, overlap: usize) -> usize { + f64::ceil((overlap * (n_crops - 1) + orig_len) as f64 / n_crops as f64) as usize + } + + let short_side = usize::min(im_h, im_w); + + let mut crop_boxes = Vec::new(); + + // Original image. + crop_boxes.push(CropBox::new(0, 0, im_w, im_h, 0)); + + for layer_idx in 1..=n_layers { + let n_crops_per_side = 1 << layer_idx; + let overlap = (overlap_ratio * short_side as f64 * 2. / n_crops_per_side as f64) as usize; + let crop_w = crop_len(im_w, n_crops_per_side, overlap); + let crop_h = crop_len(im_w, n_crops_per_side, overlap); + + for i_x in 0..n_crops_per_side { + let x0 = (crop_w - overlap) * i_x; + for i_y in 0..n_crops_per_side { + let y0 = (crop_h - overlap) * i_y; + let x1 = usize::min(im_w, x0 + crop_w); + let y1 = usize::min(im_h, y0 + crop_h); + crop_boxes.push(CropBox::new(x0, y0, x1, y1, layer_idx)); + } + } + } + + crop_boxes +} + +// Generates a 2D grid of points evenly spaced in [0,1]x[0,1]. +fn build_point_grid(n_per_side: usize) -> Vec<(f64, f64)> { + let offset = 1f64 / (2 * n_per_side) as f64; + let mut points = Vec::with_capacity(n_per_side * n_per_side); + for i_x in 0..n_per_side { + let x = offset + i_x as f64 / n_per_side as f64; + for i_y in 0..n_per_side { + let y = offset + i_y as f64 / n_per_side as f64; + points.push((x, y)) + } + } + points +} + +fn build_all_layer_point_grids( + n_per_side: usize, + n_layers: usize, + scale_per_layer: usize, +) -> Vec<Vec<(f64, f64)>> { + let mut points_by_layer = Vec::with_capacity(n_layers + 1); + for i in 0..=n_layers { + let n_points = n_per_side / scale_per_layer.pow(i as u32); + points_by_layer.push(build_point_grid(n_points)) + } + points_by_layer } |