diff options
Diffstat (limited to 'candle-examples/examples/segment-anything/model_sam.rs')
-rw-r--r-- | candle-examples/examples/segment-anything/model_sam.rs | 37 |
1 files changed, 33 insertions, 4 deletions
diff --git a/candle-examples/examples/segment-anything/model_sam.rs b/candle-examples/examples/segment-anything/model_sam.rs index 5a0d7e8f..1c8e9a59 100644 --- a/candle-examples/examples/segment-anything/model_sam.rs +++ b/candle-examples/examples/segment-anything/model_sam.rs @@ -5,6 +5,10 @@ use crate::model_image_encoder::ImageEncoderViT; use crate::model_mask_decoder::MaskDecoder; use crate::model_prompt_encoder::PromptEncoder; +const PROMPT_EMBED_DIM: usize = 256; +const IMAGE_SIZE: usize = 1024; +const VIT_PATCH_SIZE: usize = 16; + #[derive(Debug)] pub struct Sam { image_encoder: ImageEncoderViT, @@ -22,10 +26,6 @@ impl Sam { encoder_global_attn_indexes: &[usize], vb: VarBuilder, ) -> Result<Self> { - const PROMPT_EMBED_DIM: usize = 256; - const IMAGE_SIZE: usize = 1024; - const VIT_PATCH_SIZE: usize = 16; - let image_embedding_size = IMAGE_SIZE / VIT_PATCH_SIZE; let image_encoder = ImageEncoderViT::new( @@ -69,4 +69,33 @@ impl Sam { pixel_mean, }) } + + pub fn forward(&self, img: &Tensor, multimask_output: bool) -> Result<(Tensor, Tensor)> { + let img = self.preprocess(img)?.unsqueeze(0)?; + let img_embeddings = self.image_encoder.forward(&img)?; + let image_pe = self.prompt_encoder.get_dense_pe()?; + let (sparse_prompt_embeddings, dense_prompt_embeddings) = + self.prompt_encoder.forward(None, None, None)?; + let (low_res_mask, iou_predictions) = self.mask_decoder.forward( + &img_embeddings, + &image_pe, + &sparse_prompt_embeddings, + &dense_prompt_embeddings, + multimask_output, + )?; + // TODO: post-processing. + Ok((low_res_mask, iou_predictions)) + } + + fn preprocess(&self, img: &Tensor) -> Result<Tensor> { + let (c, h, w) = img.dims3()?; + let img = img + .broadcast_sub(&self.pixel_mean)? + .broadcast_div(&self.pixel_std)?; + if h > IMAGE_SIZE || w > IMAGE_SIZE { + candle::bail!("image is too large ({w}, {h}), maximum size {IMAGE_SIZE}") + } + let img = img.pad_with_zeros(1, 0, IMAGE_SIZE - h)?; + img.pad_with_zeros(2, 0, IMAGE_SIZE - w) + } } |