summaryrefslogtreecommitdiff
path: root/candle-examples/examples/segment-anything/model_sam.rs
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-09-07 19:22:45 +0100
committerGitHub <noreply@github.com>2023-09-07 19:22:45 +0100
commit7396b8ed1a5394c58fcc772e5f6e6038577505b8 (patch)
treef7ce0cf676705e800093c05884ce5fc7443b7b0b /candle-examples/examples/segment-anything/model_sam.rs
parent7b50f3e106b3d3a333e1c67f2006cbfd60c8a55a (diff)
downloadcandle-7396b8ed1a5394c58fcc772e5f6e6038577505b8.tar.gz
candle-7396b8ed1a5394c58fcc772e5f6e6038577505b8.tar.bz2
candle-7396b8ed1a5394c58fcc772e5f6e6038577505b8.zip
Segment Anything - process images (#766)
* Start processing images. * Add LayerNorm2d. * Properly use LayerNorm2d. * Tweak eps. * Use LayerNorm on inputs with a rank different from 3. * Window partitioning. * Fix a couple todos. * More todos. * Hard-code the einsums. * More padding support. * Some sizes tweaks. * Use the hub to get the weights. * Use a batch matmul. * Tweaks. * More fixes. * Get some predictions to be generated.
Diffstat (limited to 'candle-examples/examples/segment-anything/model_sam.rs')
-rw-r--r--candle-examples/examples/segment-anything/model_sam.rs37
1 file changed, 33 insertions, 4 deletions
diff --git a/candle-examples/examples/segment-anything/model_sam.rs b/candle-examples/examples/segment-anything/model_sam.rs
index 5a0d7e8f..1c8e9a59 100644
--- a/candle-examples/examples/segment-anything/model_sam.rs
+++ b/candle-examples/examples/segment-anything/model_sam.rs
@@ -5,6 +5,10 @@ use crate::model_image_encoder::ImageEncoderViT;
use crate::model_mask_decoder::MaskDecoder;
use crate::model_prompt_encoder::PromptEncoder;
+const PROMPT_EMBED_DIM: usize = 256;
+const IMAGE_SIZE: usize = 1024;
+const VIT_PATCH_SIZE: usize = 16;
+
#[derive(Debug)]
pub struct Sam {
image_encoder: ImageEncoderViT,
@@ -22,10 +26,6 @@ impl Sam {
encoder_global_attn_indexes: &[usize],
vb: VarBuilder,
) -> Result<Self> {
- const PROMPT_EMBED_DIM: usize = 256;
- const IMAGE_SIZE: usize = 1024;
- const VIT_PATCH_SIZE: usize = 16;
-
let image_embedding_size = IMAGE_SIZE / VIT_PATCH_SIZE;
let image_encoder = ImageEncoderViT::new(
@@ -69,4 +69,33 @@ impl Sam {
pixel_mean,
})
}
+
+ pub fn forward(&self, img: &Tensor, multimask_output: bool) -> Result<(Tensor, Tensor)> {
+ let img = self.preprocess(img)?.unsqueeze(0)?;
+ let img_embeddings = self.image_encoder.forward(&img)?;
+ let image_pe = self.prompt_encoder.get_dense_pe()?;
+ let (sparse_prompt_embeddings, dense_prompt_embeddings) =
+ self.prompt_encoder.forward(None, None, None)?;
+ let (low_res_mask, iou_predictions) = self.mask_decoder.forward(
+ &img_embeddings,
+ &image_pe,
+ &sparse_prompt_embeddings,
+ &dense_prompt_embeddings,
+ multimask_output,
+ )?;
+ // TODO: post-processing.
+ Ok((low_res_mask, iou_predictions))
+ }
+
+ fn preprocess(&self, img: &Tensor) -> Result<Tensor> {
+ let (c, h, w) = img.dims3()?;
+ let img = img
+ .broadcast_sub(&self.pixel_mean)?
+ .broadcast_div(&self.pixel_std)?;
+ if h > IMAGE_SIZE || w > IMAGE_SIZE {
+ candle::bail!("image is too large ({w}, {h}), maximum size {IMAGE_SIZE}")
+ }
+ let img = img.pad_with_zeros(1, 0, IMAGE_SIZE - h)?;
+ img.pad_with_zeros(2, 0, IMAGE_SIZE - w)
+ }
}