diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-09-07 19:22:45 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-09-07 19:22:45 +0100 |
commit | 7396b8ed1a5394c58fcc772e5f6e6038577505b8 (patch) | |
tree | f7ce0cf676705e800093c05884ce5fc7443b7b0b /candle-examples/examples/segment-anything/model_sam.rs | |
parent | 7b50f3e106b3d3a333e1c67f2006cbfd60c8a55a (diff) | |
download | candle-7396b8ed1a5394c58fcc772e5f6e6038577505b8.tar.gz candle-7396b8ed1a5394c58fcc772e5f6e6038577505b8.tar.bz2 candle-7396b8ed1a5394c58fcc772e5f6e6038577505b8.zip |
Segment Anything - process images (#766)
* Start processing images.
* Add LayerNorm2d.
* Properly use LayerNorm2d.
* Tweak eps.
* Use LayerNorm on inputs with a rank different from 3.
* Window partitioning.
* Fix a couple todos.
* More todos.
* Hard-code the einsums.
* More padding support.
* Some sizes tweaks.
* Use the hub to get the weights.
* Use a batch matmul.
* Tweaks.
* More fixes.
* Get some predictions to be generated.
Diffstat (limited to 'candle-examples/examples/segment-anything/model_sam.rs')
-rw-r--r-- | candle-examples/examples/segment-anything/model_sam.rs | 37 |
1 file changed, 33 insertions, 4 deletions
diff --git a/candle-examples/examples/segment-anything/model_sam.rs b/candle-examples/examples/segment-anything/model_sam.rs index 5a0d7e8f..1c8e9a59 100644 --- a/candle-examples/examples/segment-anything/model_sam.rs +++ b/candle-examples/examples/segment-anything/model_sam.rs @@ -5,6 +5,10 @@ use crate::model_image_encoder::ImageEncoderViT; use crate::model_mask_decoder::MaskDecoder; use crate::model_prompt_encoder::PromptEncoder; +const PROMPT_EMBED_DIM: usize = 256; +const IMAGE_SIZE: usize = 1024; +const VIT_PATCH_SIZE: usize = 16; + #[derive(Debug)] pub struct Sam { image_encoder: ImageEncoderViT, @@ -22,10 +26,6 @@ impl Sam { encoder_global_attn_indexes: &[usize], vb: VarBuilder, ) -> Result<Self> { - const PROMPT_EMBED_DIM: usize = 256; - const IMAGE_SIZE: usize = 1024; - const VIT_PATCH_SIZE: usize = 16; - let image_embedding_size = IMAGE_SIZE / VIT_PATCH_SIZE; let image_encoder = ImageEncoderViT::new( @@ -69,4 +69,33 @@ impl Sam { pixel_mean, }) } + + pub fn forward(&self, img: &Tensor, multimask_output: bool) -> Result<(Tensor, Tensor)> { + let img = self.preprocess(img)?.unsqueeze(0)?; + let img_embeddings = self.image_encoder.forward(&img)?; + let image_pe = self.prompt_encoder.get_dense_pe()?; + let (sparse_prompt_embeddings, dense_prompt_embeddings) = + self.prompt_encoder.forward(None, None, None)?; + let (low_res_mask, iou_predictions) = self.mask_decoder.forward( + &img_embeddings, + &image_pe, + &sparse_prompt_embeddings, + &dense_prompt_embeddings, + multimask_output, + )?; + // TODO: post-processing. + Ok((low_res_mask, iou_predictions)) + } + + fn preprocess(&self, img: &Tensor) -> Result<Tensor> { + let (c, h, w) = img.dims3()?; + let img = img + .broadcast_sub(&self.pixel_mean)? + .broadcast_div(&self.pixel_std)?; + if h > IMAGE_SIZE || w > IMAGE_SIZE { + candle::bail!("image is too large ({w}, {h}), maximum size {IMAGE_SIZE}") + } + let img = img.pad_with_zeros(1, 0, IMAGE_SIZE - h)?; + img.pad_with_zeros(2, 0, IMAGE_SIZE - w) + } } |