-rw-r--r--  candle-examples/examples/segment-anything/main.rs                  62
-rw-r--r--  candle-examples/examples/segment-anything/model_image_encoder.rs    6
-rw-r--r--  candle-examples/examples/segment-anything/model_mask_decoder.rs     4
-rw-r--r--  candle-examples/examples/segment-anything/model_prompt_encoder.rs   6
-rw-r--r--  candle-examples/examples/segment-anything/model_sam.rs            181
-rw-r--r--  candle-examples/examples/segment-anything/model_transformer.rs      6
-rw-r--r--  candle-examples/src/lib.rs                                         30
7 files changed, 251 insertions(+), 44 deletions(-)
diff --git a/candle-examples/examples/segment-anything/main.rs b/candle-examples/examples/segment-anything/main.rs
index c53c1010..0f0c0482 100644
--- a/candle-examples/examples/segment-anything/main.rs
+++ b/candle-examples/examples/segment-anything/main.rs
@@ -1,6 +1,5 @@
//! SAM: Segment Anything Model
//! https://github.com/facebookresearch/segment-anything
-#![allow(unused)]
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
@@ -14,7 +13,7 @@ pub mod model_prompt_encoder;
pub mod model_sam;
pub mod model_transformer;
-use candle::{DType, IndexOp, Result, Tensor, D};
+use candle::{DType, Result, Tensor};
use candle_nn::{Linear, Module, VarBuilder};
use clap::Parser;
@@ -101,6 +100,15 @@ struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
+
+ #[arg(long)]
+ generate_masks: bool,
+
+ #[arg(long)]
+ point_x: Option<f64>,
+
+ #[arg(long)]
+ point_y: Option<f64>,
}
pub fn main() -> anyhow::Result<()> {
@@ -108,7 +116,7 @@ pub fn main() -> anyhow::Result<()> {
let device = candle_examples::device(args.cpu)?;
- let image = if args.image.ends_with(".safetensors") {
+ let (image, initial_h, initial_w) = if args.image.ends_with(".safetensors") {
let mut tensors = candle::safetensors::load(&args.image, &device)?;
let image = match tensors.remove("image") {
Some(image) => image,
@@ -119,13 +127,16 @@ pub fn main() -> anyhow::Result<()> {
tensors.into_values().next().unwrap()
}
};
- if image.rank() == 4 {
+ let image = if image.rank() == 4 {
image.get(0)?
} else {
image
- }
+ };
+ let (_c, h, w) = image.dims3()?;
+ (image, h, w)
} else {
- candle_examples::load_image(args.image, Some(model_sam::IMAGE_SIZE))?.to_device(&device)?
+ let (image, h, w) = candle_examples::load_image(args.image, Some(model_sam::IMAGE_SIZE))?;
+ (image.to_device(&device)?, h, w)
};
println!("loaded image {image:?}");
@@ -142,19 +153,30 @@ pub fn main() -> anyhow::Result<()> {
let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
let sam = model_sam::Sam::new(768, 12, 12, &[2, 5, 8, 11], vb)?; // sam_vit_b
- let (mask, iou_predictions) = sam.forward(&image, false)?;
- println!("mask:\n{mask}");
- println!("iou_predictions: {iou_predictions:?}");
-
- // Save the mask as an image.
- let mask = mask.ge(&mask.zeros_like()?)?;
- let mask = (mask * 255.)?.squeeze(0)?;
- let (_one, h, w) = mask.dims3()?;
- let mask = mask.expand((3, h, w))?;
- candle_examples::save_image(&mask, "sam_mask.png")?;
-
- let image = sam.preprocess(&image)?;
- let image = sam.unpreprocess(&image)?.to_dtype(DType::U8)?;
- candle_examples::save_image(&image, "sam_input_scaled.png")?;
+ if args.generate_masks {
+ // Default options similar to the Python version.
+ sam.generate_masks(
+ &image,
+ /* points_per_side */ 32,
+ /* crop_n_layer */ 0,
+ /* crop_overlap_ratio */ 512. / 1500.,
+ /* crop_n_points_downscale_factor */ 1,
+ )?
+ } else {
+ let point = args.point_x.zip(args.point_y);
+ let (mask, iou_predictions) = sam.forward(&image, point, false)?;
+ println!("mask:\n{mask}");
+ println!("iou_predictions: {iou_predictions:?}");
+
+ // Save the mask as an image.
+ let mask = (mask.ge(&mask.zeros_like()?)? * 255.)?;
+ let (_one, h, w) = mask.dims3()?;
+ let mask = mask.expand((3, h, w))?;
+ candle_examples::save_image_resize(&mask, "sam_mask.png", initial_h, initial_w)?;
+
+ let image = sam.preprocess(&image)?;
+ let image = sam.unpreprocess(&image)?.to_dtype(DType::U8)?;
+ candle_examples::save_image(&image, "sam_input_scaled.png")?;
+ }
Ok(())
}
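Note: with clap's kebab-case conversion the new flags surface as --generate-masks, --point-x and --point-y, and the point coordinates are fractions of the image size in [0, 1], so --point-x 0.5 --point-y 0.5 always selects the image center regardless of resolution. A minimal sketch of that scaling convention, using a hypothetical helper name (the actual scaling happens inline in Sam::forward below):

// Hypothetical helper mirroring the scaling done in Sam::forward:
// normalized prompt coordinates are multiplied by the original image
// dimensions before reaching the prompt encoder.
fn point_to_pixels(x: f64, y: f64, w: usize, h: usize) -> (f32, f32) {
    (x as f32 * w as f32, y as f32 * h as f32)
}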
diff --git a/candle-examples/examples/segment-anything/model_image_encoder.rs b/candle-examples/examples/segment-anything/model_image_encoder.rs
index 79e52d47..f1b76e23 100644
--- a/candle-examples/examples/segment-anything/model_image_encoder.rs
+++ b/candle-examples/examples/segment-anything/model_image_encoder.rs
@@ -1,4 +1,4 @@
-use candle::{DType, IndexOp, Result, Tensor, D};
+use candle::{DType, IndexOp, Result, Tensor};
use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};
#[derive(Debug)]
@@ -37,7 +37,6 @@ struct Attention {
proj: Linear,
num_heads: usize,
scale: f64,
- use_rel_pos: bool,
rel_pos_hw: Option<(Tensor, Tensor)>,
}
@@ -66,7 +65,6 @@ impl Attention {
proj,
num_heads,
scale,
- use_rel_pos,
rel_pos_hw,
})
}
@@ -272,7 +270,6 @@ impl Module for Block {
#[derive(Debug)]
pub struct ImageEncoderViT {
- img_size: usize,
patch_embed: PatchEmbed,
blocks: Vec<Block>,
neck_conv1: candle_nn::Conv2d,
@@ -350,7 +347,6 @@ impl ImageEncoderViT {
None
};
Ok(Self {
- img_size,
patch_embed,
blocks,
neck_conv1,
diff --git a/candle-examples/examples/segment-anything/model_mask_decoder.rs b/candle-examples/examples/segment-anything/model_mask_decoder.rs
index acbfeeea..598af1f6 100644
--- a/candle-examples/examples/segment-anything/model_mask_decoder.rs
+++ b/candle-examples/examples/segment-anything/model_mask_decoder.rs
@@ -1,4 +1,4 @@
-use candle::{DType, IndexOp, Result, Tensor, D};
+use candle::{IndexOp, Result, Tensor};
use candle_nn::{Linear, Module, VarBuilder};
use crate::model_transformer::TwoWayTransformer;
@@ -188,7 +188,7 @@ impl MaskDecoder {
// Expand per-image data in batch direction to be per mask
let src = repeat_interleave(image_embeddings, tokens.dim(0)?, 0)?;
- let src = (src + dense_prompt_embeddings)?;
+ let src = src.broadcast_add(dense_prompt_embeddings)?;
let pos_src = repeat_interleave(image_pe, tokens.dim(0)?, 0)?;
let (b, c, h, w) = src.dims4()?;
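Note: the switch to broadcast_add above is needed because repeat_interleave expands src to one row per mask token while the dense prompt embeddings keep a batch dimension of 1; a plain + requires identical shapes. A minimal self-contained sketch (the shapes are illustrative assumptions, not the model's real sizes):

use candle::{DType, Device, Result, Tensor};

fn broadcast_add_demo() -> Result<()> {
    let dev = Device::Cpu;
    // src is repeated per mask token: batch 4 in this sketch.
    let src = Tensor::zeros((4, 8, 16, 16), DType::F32, &dev)?;
    // Dense prompt embeddings keep batch 1.
    let dense = Tensor::ones((1, 8, 16, 16), DType::F32, &dev)?;
    // broadcast_add expands the size-1 batch dimension; `src + dense`
    // would fail here with a shape mismatch.
    let _sum = src.broadcast_add(&dense)?;
    Ok(())
}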
diff --git a/candle-examples/examples/segment-anything/model_prompt_encoder.rs b/candle-examples/examples/segment-anything/model_prompt_encoder.rs
index e4291ebb..b401a900 100644
--- a/candle-examples/examples/segment-anything/model_prompt_encoder.rs
+++ b/candle-examples/examples/segment-anything/model_prompt_encoder.rs
@@ -1,5 +1,5 @@
use candle::{DType, IndexOp, Result, Tensor, D};
-use candle_nn::{Linear, Module, VarBuilder};
+use candle_nn::VarBuilder;
#[derive(Debug)]
struct PostionEmbeddingRandom {
@@ -24,7 +24,6 @@ impl PostionEmbeddingRandom {
fn forward(&self, h: usize, w: usize) -> Result<Tensor> {
let device = self.positional_encoding_gaussian_matrix.device();
- let grid = Tensor::ones((h, w), DType::F32, device)?;
let x_embed = (Tensor::arange(0u32, w as u32, device)?.to_dtype(DType::F32)? + 0.5)?;
let y_embed = (Tensor::arange(0u32, h as u32, device)?.to_dtype(DType::F32)? + 0.5)?;
let x_embed = (x_embed / w as f64)?
@@ -157,8 +156,9 @@ impl PromptEncoder {
let point_embedding = self
.pe_layer
.forward_with_coords(&points, self.input_image_size)?;
+ let labels = labels.unsqueeze(2)?.broadcast_as(point_embedding.shape())?;
let zeros = point_embedding.zeros_like()?;
- let point_embeddings = labels.lt(&labels.zeros_like()?)?.where_cond(
+ let point_embedding = labels.lt(&labels.zeros_like()?)?.where_cond(
&self
.not_a_point_embed
.embeddings()
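Note: the new unsqueeze/broadcast_as line exists because where_cond is elementwise, so the (batch, n_points) labels must first be expanded to the (batch, n_points, embed_dim) shape of the embeddings. A small sketch of the pattern with toy shapes and a hypothetical demo function:

use candle::{DType, Device, Result, Tensor};

fn where_cond_demo() -> Result<()> {
    let dev = Device::Cpu;
    // A label of -1 marks a padding point whose embedding gets replaced.
    let labels = Tensor::new(&[[1f32, -1f32]], &dev)?; // (1, 2)
    let point_embedding = Tensor::ones((1, 2, 4), DType::F32, &dev)?; // (1, 2, 4)
    let not_a_point = point_embedding.zeros_like()?;
    // Expand the labels to the embedding shape before comparing,
    // mirroring the fix above.
    let labels = labels.unsqueeze(2)?.broadcast_as(point_embedding.shape())?;
    let _out = labels
        .lt(&labels.zeros_like()?)?
        .where_cond(&not_a_point, &point_embedding)?;
    Ok(())
}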
diff --git a/candle-examples/examples/segment-anything/model_sam.rs b/candle-examples/examples/segment-anything/model_sam.rs
index 237163a3..884559af 100644
--- a/candle-examples/examples/segment-anything/model_sam.rs
+++ b/candle-examples/examples/segment-anything/model_sam.rs
@@ -1,5 +1,5 @@
-use candle::{DType, IndexOp, Result, Tensor, D};
-use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};
+use candle::{DType, IndexOp, Result, Tensor};
+use candle_nn::{Module, VarBuilder};
use crate::model_image_encoder::ImageEncoderViT;
use crate::model_mask_decoder::MaskDecoder;
@@ -70,12 +70,30 @@ impl Sam {
})
}
- pub fn forward(&self, img: &Tensor, multimask_output: bool) -> Result<(Tensor, Tensor)> {
+ pub fn forward(
+ &self,
+ img: &Tensor,
+ point: Option<(f64, f64)>,
+ multimask_output: bool,
+ ) -> Result<(Tensor, Tensor)> {
+ let (_c, original_h, original_w) = img.dims3()?;
let img = self.preprocess(img)?.unsqueeze(0)?;
let img_embeddings = self.image_encoder.forward(&img)?;
let image_pe = self.prompt_encoder.get_dense_pe()?;
+ let points = match point {
+ None => None,
+ Some((x, y)) => {
+ let points = Tensor::new(
+ &[[[x as f32 * original_w as f32, y as f32 * original_h as f32]]],
+ img.device(),
+ )?;
+ let labels = Tensor::ones((1, 1), DType::F32, img.device())?;
+ Some((points, labels))
+ }
+ };
+ let points = points.as_ref().map(|(x, y)| (x, y));
let (sparse_prompt_embeddings, dense_prompt_embeddings) =
- self.prompt_encoder.forward(None, None, None)?;
+ self.prompt_encoder.forward(points, None, None)?;
let (low_res_mask, iou_predictions) = self.mask_decoder.forward(
&img_embeddings,
&image_pe,
@@ -83,8 +101,11 @@ impl Sam {
&dense_prompt_embeddings,
multimask_output,
)?;
- // TODO: post-processing.
- Ok((low_res_mask, iou_predictions))
+ let mask = low_res_mask
+ .upsample_nearest2d(IMAGE_SIZE, IMAGE_SIZE)?
+ .get(0)?
+ .i((.., ..original_h, ..original_w))?;
+ Ok((mask, iou_predictions))
}
pub fn unpreprocess(&self, img: &Tensor) -> Result<Tensor> {
@@ -96,7 +117,7 @@ impl Sam {
}
pub fn preprocess(&self, img: &Tensor) -> Result<Tensor> {
- let (c, h, w) = img.dims3()?;
+ let (_c, h, w) = img.dims3()?;
let img = img
.to_dtype(DType::F32)?
.broadcast_sub(&self.pixel_mean)?
@@ -107,4 +128,150 @@ impl Sam {
let img = img.pad_with_zeros(1, 0, IMAGE_SIZE - h)?;
img.pad_with_zeros(2, 0, IMAGE_SIZE - w)
}
+
+ fn process_crop(&self, img: &Tensor, cb: CropBox, point_grids: &[(f64, f64)]) -> Result<()> {
+ // Crop the image and calculate embeddings.
+ let img = img.i((.., cb.y0..cb.y1, cb.x0..cb.x1))?;
+ let img = self.preprocess(&img)?.unsqueeze(0)?;
+ let img_embeddings = self.image_encoder.forward(&img)?;
+
+ let crop_w = cb.x1 - cb.x0;
+ let crop_h = cb.y1 - cb.y0;
+
+ // Generate masks for this crop.
+ let image_pe = self.prompt_encoder.get_dense_pe()?;
+ let points = point_grids
+ .iter()
+ .map(|&(x, y)| vec![x as f32 * crop_w as f32, y as f32 * crop_h as f32])
+ .collect::<Vec<_>>();
+ for points in points.chunks(64) {
+ let points_len = points.len();
+ let in_points = Tensor::new(points.to_vec(), img.device())?.unsqueeze(1)?;
+ let in_labels = Tensor::ones((points_len, 1), DType::F32, img.device())?;
+ let (sparse_prompt_embeddings, dense_prompt_embeddings) =
+ self.prompt_encoder
+ .forward(Some((&in_points, &in_labels)), None, None)?;
+ let (_low_res_mask, iou_predictions) = self.mask_decoder.forward(
+ &img_embeddings,
+ &image_pe,
+ &sparse_prompt_embeddings,
+ &dense_prompt_embeddings,
+ /* multimask_output */ true,
+ )?;
+
+ println!("{cb:?} {iou_predictions}");
+ }
+
+ // Remove duplicates within this crop.
+
+ // Return to the original image frame.
+ Ok(())
+ }
+
+ pub fn generate_masks(
+ &self,
+ img: &Tensor,
+ points_per_side: usize,
+ crop_n_layer: usize,
+ crop_overlap_ratio: f64,
+ crop_n_points_downscale_factor: usize,
+ ) -> Result<()> {
+ let (_c, h, w) = img.dims3()?;
+ let point_grids = build_all_layer_point_grids(
+ points_per_side,
+ crop_n_layer,
+ crop_n_points_downscale_factor,
+ );
+ let crop_boxes = generate_crop_boxes((h, w), crop_n_layer, crop_overlap_ratio);
+ for crop_box in crop_boxes.into_iter() {
+ let layer_idx = crop_box.layer_idx;
+ self.process_crop(img, crop_box, &point_grids[layer_idx])?
+ }
+ // TODO: remove duplicates
+ Ok(())
+ }
+}
+
+#[derive(Debug)]
+struct CropBox {
+ x0: usize,
+ y0: usize,
+ x1: usize,
+ y1: usize,
+ layer_idx: usize,
+}
+
+impl CropBox {
+ fn new(x0: usize, y0: usize, x1: usize, y1: usize, layer_idx: usize) -> Self {
+ Self {
+ x0,
+ y0,
+ x1,
+ y1,
+ layer_idx,
+ }
+ }
+}
+
+fn generate_crop_boxes(
+ (im_h, im_w): (usize, usize),
+ n_layers: usize,
+ overlap_ratio: f64,
+) -> Vec<CropBox> {
+ fn crop_len(orig_len: usize, n_crops: usize, overlap: usize) -> usize {
+ f64::ceil((overlap * (n_crops - 1) + orig_len) as f64 / n_crops as f64) as usize
+ }
+
+ let short_side = usize::min(im_h, im_w);
+
+ let mut crop_boxes = Vec::new();
+
+ // Original image.
+ crop_boxes.push(CropBox::new(0, 0, im_w, im_h, 0));
+
+ for layer_idx in 1..=n_layers {
+ let n_crops_per_side = 1 << layer_idx;
+ let overlap = (overlap_ratio * short_side as f64 * 2. / n_crops_per_side as f64) as usize;
+ let crop_w = crop_len(im_w, n_crops_per_side, overlap);
+ let crop_h = crop_len(im_h, n_crops_per_side, overlap);
+
+ for i_x in 0..n_crops_per_side {
+ let x0 = (crop_w - overlap) * i_x;
+ for i_y in 0..n_crops_per_side {
+ let y0 = (crop_h - overlap) * i_y;
+ let x1 = usize::min(im_w, x0 + crop_w);
+ let y1 = usize::min(im_h, y0 + crop_h);
+ crop_boxes.push(CropBox::new(x0, y0, x1, y1, layer_idx));
+ }
+ }
+ }
+
+ crop_boxes
+}
+
+// Generates a 2D grid of points evenly spaced in [0,1]x[0,1].
+fn build_point_grid(n_per_side: usize) -> Vec<(f64, f64)> {
+ let offset = 1f64 / (2 * n_per_side) as f64;
+ let mut points = Vec::with_capacity(n_per_side * n_per_side);
+ for i_x in 0..n_per_side {
+ let x = offset + i_x as f64 / n_per_side as f64;
+ for i_y in 0..n_per_side {
+ let y = offset + i_y as f64 / n_per_side as f64;
+ points.push((x, y))
+ }
+ }
+ points
+}
+
+fn build_all_layer_point_grids(
+ n_per_side: usize,
+ n_layers: usize,
+ scale_per_layer: usize,
+) -> Vec<Vec<(f64, f64)>> {
+ let mut points_by_layer = Vec::with_capacity(n_layers + 1);
+ for i in 0..=n_layers {
+ let n_points = n_per_side / scale_per_layer.pow(i as u32);
+ points_by_layer.push(build_point_grid(n_points))
+ }
+ points_by_layer
}
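Note: a quick worked check of the grid and crop helpers above (a hypothetical test; it assumes the private functions are reachable from a test module in the same file):

#[test]
fn grids_and_crops() {
    // n_per_side = 2: offset = 1/4, so points sit at 0.25 and 0.75 per axis.
    let grid = build_point_grid(2);
    assert_eq!(grid.len(), 4);
    assert_eq!(grid[0], (0.25, 0.25));
    // One grid per layer, each side scaled down by the downscale factor.
    let grids = build_all_layer_point_grids(32, 1, 2);
    assert_eq!(grids[0].len(), 32 * 32);
    assert_eq!(grids[1].len(), 16 * 16);
    // n_layers = 1: the full image plus a 2x2 grid of overlapping crops.
    let boxes = generate_crop_boxes((1500, 2250), 1, 512. / 1500.);
    assert_eq!(boxes.len(), 1 + 4);
}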
diff --git a/candle-examples/examples/segment-anything/model_transformer.rs b/candle-examples/examples/segment-anything/model_transformer.rs
index 044dce9b..e4de27cb 100644
--- a/candle-examples/examples/segment-anything/model_transformer.rs
+++ b/candle-examples/examples/segment-anything/model_transformer.rs
@@ -1,4 +1,4 @@
-use candle::{DType, IndexOp, Result, Tensor, D};
+use candle::{Result, Tensor};
use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};
#[derive(Debug)]
@@ -7,7 +7,6 @@ struct Attention {
k_proj: Linear,
v_proj: Linear,
out_proj: Linear,
- internal_dim: usize,
num_heads: usize,
}
@@ -28,7 +27,6 @@ impl Attention {
k_proj,
v_proj,
out_proj,
- internal_dim,
num_heads,
})
}
@@ -85,7 +83,6 @@ impl TwoWayAttentionBlock {
skip_first_layer_pe: bool,
vb: VarBuilder,
) -> Result<Self> {
- let self_attn = Attention::new(embedding_dim, num_heads, 1, vb.pp("self_attn"))?;
let norm1 = layer_norm(embedding_dim, 1e-5, vb.pp("norm1"))?;
let norm2 = layer_norm(embedding_dim, 1e-5, vb.pp("norm2"))?;
let norm3 = layer_norm(embedding_dim, 1e-5, vb.pp("norm3"))?;
@@ -204,7 +201,6 @@ impl TwoWayTransformer {
image_pe: &Tensor,
point_embedding: &Tensor,
) -> Result<(Tensor, Tensor)> {
- let (bs, c, h, w) = image_embedding.dims4()?;
let image_embedding = image_embedding.flatten_from(2)?.permute((0, 2, 1))?;
let image_pe = image_pe.flatten_from(2)?.permute((0, 2, 1))?;
diff --git a/candle-examples/src/lib.rs b/candle-examples/src/lib.rs
index 66cd2f99..c14b2d6b 100644
--- a/candle-examples/src/lib.rs
+++ b/candle-examples/src/lib.rs
@@ -19,10 +19,11 @@ pub fn device(cpu: bool) -> Result<Device> {
pub fn load_image<P: AsRef<std::path::Path>>(
p: P,
resize_longest: Option<usize>,
-) -> Result<Tensor> {
+) -> Result<(Tensor, usize, usize)> {
let img = image::io::Reader::open(p)?
.decode()
.map_err(candle::Error::wrap)?;
+ let (initial_h, initial_w) = (img.height() as usize, img.width() as usize);
let img = match resize_longest {
None => img,
Some(resize_longest) => {
@@ -41,7 +42,8 @@ pub fn load_image<P: AsRef<std::path::Path>>(
let (height, width) = (img.height() as usize, img.width() as usize);
let img = img.to_rgb8();
let data = img.into_raw();
- Tensor::from_vec(data, (height, width, 3), &Device::Cpu)?.permute((2, 0, 1))
+ let data = Tensor::from_vec(data, (height, width, 3), &Device::Cpu)?.permute((2, 0, 1))?;
+ Ok((data, initial_h, initial_w))
}
pub fn load_image_and_resize<P: AsRef<std::path::Path>>(
@@ -80,3 +82,27 @@ pub fn save_image<P: AsRef<std::path::Path>>(img: &Tensor, p: P) -> Result<()> {
image.save(p).map_err(candle::Error::wrap)?;
Ok(())
}
+
+pub fn save_image_resize<P: AsRef<std::path::Path>>(
+ img: &Tensor,
+ p: P,
+ h: usize,
+ w: usize,
+) -> Result<()> {
+ let p = p.as_ref();
+ let (channel, height, width) = img.dims3()?;
+ if channel != 3 {
+ candle::bail!("save_image expects an input of shape (3, height, width)")
+ }
+ let img = img.permute((1, 2, 0))?.flatten_all()?;
+ let pixels = img.to_vec1::<u8>()?;
+ let image: image::ImageBuffer<image::Rgb<u8>, Vec<u8>> =
+ match image::ImageBuffer::from_raw(width as u32, height as u32, pixels) {
+ Some(image) => image,
+ None => candle::bail!("error saving image {p:?}"),
+ };
+ let image = image::DynamicImage::from(image);
+ let image = image.resize_to_fill(w as u32, h as u32, image::imageops::FilterType::CatmullRom);
+ image.save(p).map_err(candle::Error::wrap)?;
+ Ok(())
+}
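Note: taken together, the helper changes let callers round-trip between model space and the original image resolution; a minimal sketch of the intended flow (file names are placeholders, and the model step is elided):

use candle::{DType, Result};

fn mask_roundtrip() -> Result<()> {
    // load_image now also returns the pre-resize height and width...
    let (img, initial_h, initial_w) = candle_examples::load_image("in.jpg", Some(1024))?;
    // ...run the model on `img` to produce a (3, h, w) u8 mask, elided here...
    let mask = img.to_dtype(DType::U8)?;
    // ...which save_image_resize uses to map the mask back onto the
    // original resolution.
    candle_examples::save_image_resize(&mask, "sam_mask.png", initial_h, initial_w)?;
    Ok(())
}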