summaryrefslogtreecommitdiff
path: root/candle-examples/examples/segment-anything/model_sam.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-examples/examples/segment-anything/model_sam.rs')
-rw-r--r--candle-examples/examples/segment-anything/model_sam.rs411
1 files changed, 0 insertions, 411 deletions
diff --git a/candle-examples/examples/segment-anything/model_sam.rs b/candle-examples/examples/segment-anything/model_sam.rs
deleted file mode 100644
index b1a81af6..00000000
--- a/candle-examples/examples/segment-anything/model_sam.rs
+++ /dev/null
@@ -1,411 +0,0 @@
-use candle::{DType, IndexOp, Result, Tensor};
-use candle_nn::{Module, VarBuilder};
-
-use crate::model_image_encoder::ImageEncoderViT;
-use crate::model_mask_decoder::MaskDecoder;
-use crate::model_prompt_encoder::PromptEncoder;
-use crate::model_tiny_vit::{tiny_vit_5m, TinyViT};
-
-const PROMPT_EMBED_DIM: usize = 256;
-pub const IMAGE_SIZE: usize = 1024;
-const VIT_PATCH_SIZE: usize = 16;
-const PRED_IOU_THRESH: f32 = 0.88;
-const STABILITY_SCORE_OFFSET: f32 = 1.0;
-const STABILITY_SCORE_THRESHOLD: f32 = 0.95;
-const MODEL_MASK_THRESHOLD: f32 = 0.0;
-const CROP_NMS_THRESH: f32 = 0.7;
-
-#[derive(Debug)]
-enum ImageEncoder {
- Original(ImageEncoderViT),
- TinyViT(TinyViT),
-}
-
-impl Module for ImageEncoder {
- fn forward(&self, xs: &Tensor) -> Result<Tensor> {
- match self {
- Self::Original(vit) => vit.forward(xs),
- Self::TinyViT(vit) => vit.forward(xs),
- }
- }
-}
-
-#[derive(Debug)]
-pub struct Sam {
- image_encoder: ImageEncoder,
- prompt_encoder: PromptEncoder,
- mask_decoder: MaskDecoder,
- pixel_mean: Tensor,
- pixel_std: Tensor,
-}
-
-impl Sam {
- pub fn new(
- encoder_embed_dim: usize,
- encoder_depth: usize,
- encoder_num_heads: usize,
- encoder_global_attn_indexes: &[usize],
- vb: VarBuilder,
- ) -> Result<Self> {
- let image_embedding_size = IMAGE_SIZE / VIT_PATCH_SIZE;
-
- let image_encoder = ImageEncoderViT::new(
- IMAGE_SIZE,
- VIT_PATCH_SIZE,
- 3,
- encoder_embed_dim,
- encoder_depth,
- encoder_num_heads,
- PROMPT_EMBED_DIM,
- /* qkv_bias */ true,
- /* use_rel_pos */ true,
- /* use_abs_pos */ true,
- /* window_size */ 14,
- /* global_attn_indexes */ encoder_global_attn_indexes,
- vb.pp("image_encoder"),
- )?;
- let prompt_encoder = PromptEncoder::new(
- PROMPT_EMBED_DIM,
- (image_embedding_size, image_embedding_size),
- (IMAGE_SIZE, IMAGE_SIZE),
- 16,
- vb.pp("prompt_encoder"),
- )?;
- let mask_decoder = MaskDecoder::new(
- PROMPT_EMBED_DIM,
- /* num_multitask_outputs */ 3,
- /* iou_head_depth */ 3,
- /* iou_head_hidden_dim */ 256,
- vb.pp("mask_decoder"),
- )?;
- let pixel_mean =
- Tensor::new(&[123.675f32, 116.28, 103.53], vb.device())?.reshape((3, 1, 1))?;
- let pixel_std =
- Tensor::new(&[58.395f32, 57.12, 57.375], vb.device())?.reshape((3, 1, 1))?;
- Ok(Self {
- image_encoder: ImageEncoder::Original(image_encoder),
- prompt_encoder,
- mask_decoder,
- pixel_std,
- pixel_mean,
- })
- }
-
- pub fn new_tiny(vb: VarBuilder) -> Result<Self> {
- let image_embedding_size = IMAGE_SIZE / VIT_PATCH_SIZE;
-
- let image_encoder = tiny_vit_5m(vb.pp("image_encoder"))?;
- let prompt_encoder = PromptEncoder::new(
- PROMPT_EMBED_DIM,
- (image_embedding_size, image_embedding_size),
- (IMAGE_SIZE, IMAGE_SIZE),
- 16,
- vb.pp("prompt_encoder"),
- )?;
- let mask_decoder = MaskDecoder::new(
- PROMPT_EMBED_DIM,
- /* num_multitask_outputs */ 3,
- /* iou_head_depth */ 3,
- /* iou_head_hidden_dim */ 256,
- vb.pp("mask_decoder"),
- )?;
- let pixel_mean =
- Tensor::new(&[123.675f32, 116.28, 103.53], vb.device())?.reshape((3, 1, 1))?;
- let pixel_std =
- Tensor::new(&[58.395f32, 57.12, 57.375], vb.device())?.reshape((3, 1, 1))?;
- Ok(Self {
- image_encoder: ImageEncoder::TinyViT(image_encoder),
- prompt_encoder,
- mask_decoder,
- pixel_std,
- pixel_mean,
- })
- }
-
- pub fn forward(
- &self,
- img: &Tensor,
- point: Option<(f64, f64)>,
- multimask_output: bool,
- ) -> Result<(Tensor, Tensor)> {
- let (_c, original_h, original_w) = img.dims3()?;
- let img = self.preprocess(img)?.unsqueeze(0)?;
- let img_embeddings = self.image_encoder.forward(&img)?;
- let image_pe = self.prompt_encoder.get_dense_pe()?;
- let points = match point {
- None => None,
- Some((x, y)) => {
- let points = Tensor::new(
- &[[[x as f32 * original_w as f32, y as f32 * original_h as f32]]],
- img.device(),
- )?;
- let labels = Tensor::ones((1, 1), DType::F32, img.device())?;
- Some((points, labels))
- }
- };
- let points = points.as_ref().map(|(x, y)| (x, y));
- let (sparse_prompt_embeddings, dense_prompt_embeddings) =
- self.prompt_encoder.forward(points, None, None)?;
- let (low_res_mask, iou_predictions) = self.mask_decoder.forward(
- &img_embeddings,
- &image_pe,
- &sparse_prompt_embeddings,
- &dense_prompt_embeddings,
- multimask_output,
- )?;
- let mask = low_res_mask
- .upsample_nearest2d(IMAGE_SIZE, IMAGE_SIZE)?
- .get(0)?
- .i((.., ..original_h, ..original_w))?;
- Ok((mask, iou_predictions))
- }
-
- pub fn unpreprocess(&self, img: &Tensor) -> Result<Tensor> {
- let img = img
- .broadcast_mul(&self.pixel_std)?
- .broadcast_add(&self.pixel_mean)?;
- img.maximum(&img.zeros_like()?)?
- .minimum(&(img.ones_like()? * 255.)?)
- }
-
- pub fn preprocess(&self, img: &Tensor) -> Result<Tensor> {
- let (_c, h, w) = img.dims3()?;
- let img = img
- .to_dtype(DType::F32)?
- .broadcast_sub(&self.pixel_mean)?
- .broadcast_div(&self.pixel_std)?;
- if h > IMAGE_SIZE || w > IMAGE_SIZE {
- candle::bail!("image is too large ({w}, {h}), maximum size {IMAGE_SIZE}")
- }
- let img = img.pad_with_zeros(1, 0, IMAGE_SIZE - h)?;
- img.pad_with_zeros(2, 0, IMAGE_SIZE - w)
- }
-
- fn process_crop(
- &self,
- img: &Tensor,
- cb: CropBox,
- point_grids: &[(f64, f64)],
- ) -> Result<Vec<candle_examples::object_detection::Bbox<Tensor>>> {
- // Crop the image and calculate embeddings.
- let img = img.i((.., cb.y0..cb.y1, cb.x0..cb.x1))?;
- let img = self.preprocess(&img)?.unsqueeze(0)?;
- let img_embeddings = self.image_encoder.forward(&img)?;
-
- let crop_w = cb.x1 - cb.x0;
- let crop_h = cb.y1 - cb.y0;
-
- // Generate masks for this crop.
- let image_pe = self.prompt_encoder.get_dense_pe()?;
- let points = point_grids
- .iter()
- .map(|&(x, y)| vec![x as f32 * crop_w as f32, y as f32 * crop_h as f32])
- .collect::<Vec<_>>();
-
- let mut bboxes = Vec::new();
- for points in points.chunks(64) {
- // Run the model on this batch.
- let points_len = points.len();
- let in_points = Tensor::new(points.to_vec(), img.device())?.unsqueeze(1)?;
- let in_labels = Tensor::ones((points_len, 1), DType::F32, img.device())?;
- let (sparse_prompt_embeddings, dense_prompt_embeddings) =
- self.prompt_encoder
- .forward(Some((&in_points, &in_labels)), None, None)?;
-
- let (low_res_mask, iou_predictions) = self.mask_decoder.forward(
- &img_embeddings,
- &image_pe,
- &sparse_prompt_embeddings,
- &dense_prompt_embeddings,
- /* multimask_output */ true,
- )?;
- let low_res_mask = low_res_mask.flatten(0, 1)?;
- let iou_predictions = iou_predictions.flatten(0, 1)?.to_vec1::<f32>()?;
- let dev = low_res_mask.device();
-
- for (i, iou) in iou_predictions.iter().enumerate() {
- // Filter by predicted IoU.
- if *iou < PRED_IOU_THRESH {
- continue;
- }
- let low_res_mask = low_res_mask.get(i)?;
-
- // Calculate stability score.
- let bound = Tensor::new(MODEL_MASK_THRESHOLD + STABILITY_SCORE_OFFSET, dev)?
- .broadcast_as(low_res_mask.shape())?;
- let intersections = low_res_mask
- .ge(&bound)?
- .to_dtype(DType::F32)?
- .sum_all()?
- .to_vec0::<f32>()?;
- let bound = Tensor::new(MODEL_MASK_THRESHOLD - STABILITY_SCORE_OFFSET, dev)?
- .broadcast_as(low_res_mask.shape())?;
- let unions = low_res_mask
- .ge(&bound)?
- .to_dtype(DType::F32)?
- .sum_all()?
- .to_vec0::<f32>()?;
- let stability_score = intersections / unions;
- if stability_score < STABILITY_SCORE_THRESHOLD {
- continue;
- }
-
- // Threshold masks and calculate boxes.
- let low_res_mask = low_res_mask
- .ge(&Tensor::new(0f32, dev)?.broadcast_as(low_res_mask.shape())?)?
- .to_dtype(DType::U32)?;
- let low_res_mask_per_x = low_res_mask.sum(0)?.to_vec1::<u32>()?;
- let low_res_mask_per_y = low_res_mask.sum(1)?.to_vec1::<u32>()?;
- let min_max_x = min_max_indexes(&low_res_mask_per_x);
- let min_max_y = min_max_indexes(&low_res_mask_per_y);
- if let Some(((x0, x1), (y0, y1))) = min_max_x.zip(min_max_y) {
- let bbox = candle_examples::object_detection::Bbox {
- xmin: x0 as f32,
- ymin: y0 as f32,
- xmax: x1 as f32,
- ymax: y1 as f32,
- confidence: *iou,
- data: low_res_mask,
- };
- bboxes.push(bbox);
- }
- // TODO:
- // Filter boxes that touch crop boundaries
- // Compress to RLE.
- }
- }
-
- let mut bboxes = vec![bboxes];
- // Remove duplicates within this crop.
- candle_examples::object_detection::non_maximum_suppression(&mut bboxes, CROP_NMS_THRESH);
-
- // TODO: Return to the original image frame.
- Ok(bboxes.remove(0))
- }
-
- pub fn generate_masks(
- &self,
- img: &Tensor,
- points_per_side: usize,
- crop_n_layer: usize,
- crop_overlap_ratio: f64,
- crop_n_points_downscale_factor: usize,
- ) -> Result<Vec<candle_examples::object_detection::Bbox<Tensor>>> {
- let (_c, h, w) = img.dims3()?;
- let point_grids = build_all_layer_point_grids(
- points_per_side,
- crop_n_layer,
- crop_n_points_downscale_factor,
- );
- let crop_boxes = generate_crop_boxes((h, w), crop_n_layer, crop_overlap_ratio);
- let mut bboxes = Vec::new();
- for crop_box in crop_boxes.into_iter() {
- let layer_idx = crop_box.layer_idx;
- let b = self.process_crop(img, crop_box, &point_grids[layer_idx])?;
- bboxes.extend(b)
- }
- // TODO: remove duplicates
- Ok(bboxes)
- }
-}
-
-// Return the first and last indexes i for which values[i] > 0
-fn min_max_indexes(values: &[u32]) -> Option<(usize, usize)> {
- let (mut min_i, mut max_i) = (usize::MAX, usize::MIN);
- for (i, &s) in values.iter().enumerate() {
- if s == 0 {
- continue;
- }
- min_i = usize::min(i, min_i);
- max_i = usize::max(i, max_i);
- }
- if max_i < min_i {
- None
- } else {
- Some((min_i, max_i))
- }
-}
-
-#[derive(Debug)]
-struct CropBox {
- x0: usize,
- y0: usize,
- x1: usize,
- y1: usize,
- layer_idx: usize,
-}
-
-impl CropBox {
- fn new(x0: usize, y0: usize, x1: usize, y1: usize, layer_idx: usize) -> Self {
- Self {
- x0,
- y0,
- x1,
- y1,
- layer_idx,
- }
- }
-}
-
-fn generate_crop_boxes(
- (im_h, im_w): (usize, usize),
- n_layers: usize,
- overlap_ratio: f64,
-) -> Vec<CropBox> {
- fn crop_len(orig_len: usize, n_crops: usize, overlap: usize) -> usize {
- f64::ceil((overlap * (n_crops - 1) + orig_len) as f64 / n_crops as f64) as usize
- }
-
- let short_side = usize::min(im_h, im_w);
-
- let mut crop_boxes = Vec::new();
-
- // Original image.
- crop_boxes.push(CropBox::new(0, 0, im_w, im_h, 0));
-
- for layer_idx in 1..=n_layers {
- let n_crops_per_side = 1 << layer_idx;
- let overlap = (overlap_ratio * short_side as f64 * 2. / n_crops_per_side as f64) as usize;
- let crop_w = crop_len(im_w, n_crops_per_side, overlap);
- let crop_h = crop_len(im_w, n_crops_per_side, overlap);
-
- for i_x in 0..n_crops_per_side {
- let x0 = (crop_w - overlap) * i_x;
- for i_y in 0..n_crops_per_side {
- let y0 = (crop_h - overlap) * i_y;
- let x1 = usize::min(im_w, x0 + crop_w);
- let y1 = usize::min(im_h, y0 + crop_h);
- crop_boxes.push(CropBox::new(x0, y0, x1, y1, layer_idx));
- }
- }
- }
-
- crop_boxes
-}
-
-// Generates a 2D grid of points evenly spaced in [0,1]x[0,1].
-fn build_point_grid(n_per_side: usize) -> Vec<(f64, f64)> {
- let offset = 1f64 / (2 * n_per_side) as f64;
- let mut points = Vec::with_capacity(n_per_side * n_per_side);
- for i_x in 0..n_per_side {
- let x = offset + i_x as f64 / n_per_side as f64;
- for i_y in 0..n_per_side {
- let y = offset + i_y as f64 / n_per_side as f64;
- points.push((x, y))
- }
- }
- points
-}
-
-fn build_all_layer_point_grids(
- n_per_side: usize,
- n_layers: usize,
- scale_per_layer: usize,
-) -> Vec<Vec<(f64, f64)>> {
- let mut points_by_layer = Vec::with_capacity(n_layers + 1);
- for i in 0..=n_layers {
- let n_points = n_per_side / scale_per_layer.pow(i as u32);
- points_by_layer.push(build_point_grid(n_points))
- }
- points_by_layer
-}