diff options
Diffstat (limited to 'candle-examples/examples/segment-anything/model_sam.rs')
-rw-r--r-- | candle-examples/examples/segment-anything/model_sam.rs | 72 |
1 files changed, 72 insertions, 0 deletions
diff --git a/candle-examples/examples/segment-anything/model_sam.rs b/candle-examples/examples/segment-anything/model_sam.rs new file mode 100644 index 00000000..5a0d7e8f --- /dev/null +++ b/candle-examples/examples/segment-anything/model_sam.rs @@ -0,0 +1,72 @@ +use candle::{DType, IndexOp, Result, Tensor, D}; +use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder}; + +use crate::model_image_encoder::ImageEncoderViT; +use crate::model_mask_decoder::MaskDecoder; +use crate::model_prompt_encoder::PromptEncoder; + +#[derive(Debug)] +pub struct Sam { + image_encoder: ImageEncoderViT, + prompt_encoder: PromptEncoder, + mask_decoder: MaskDecoder, + pixel_mean: Tensor, + pixel_std: Tensor, +} + +impl Sam { + pub fn new( + encoder_embed_dim: usize, + encoder_depth: usize, + encoder_num_heads: usize, + encoder_global_attn_indexes: &[usize], + vb: VarBuilder, + ) -> Result<Self> { + const PROMPT_EMBED_DIM: usize = 256; + const IMAGE_SIZE: usize = 1024; + const VIT_PATCH_SIZE: usize = 16; + + let image_embedding_size = IMAGE_SIZE / VIT_PATCH_SIZE; + + let image_encoder = ImageEncoderViT::new( + IMAGE_SIZE, + VIT_PATCH_SIZE, + 3, + encoder_embed_dim, + encoder_depth, + encoder_num_heads, + PROMPT_EMBED_DIM, + /* qkv_bias */ true, + /* use_rel_pos */ true, + /* use_abs_pos */ true, + /* window_size */ 14, + /* global_attn_indexes */ encoder_global_attn_indexes, + vb.pp("image_encoder"), + )?; + let prompt_encoder = PromptEncoder::new( + PROMPT_EMBED_DIM, + (image_embedding_size, image_embedding_size), + (IMAGE_SIZE, IMAGE_SIZE), + 16, + vb.pp("prompt_encoder"), + )?; + let mask_decoder = MaskDecoder::new( + PROMPT_EMBED_DIM, + /* num_multitask_outputs */ 3, + /* iou_head_depth */ 3, + /* iou_head_hidden_dim */ 256, + vb.pp("mask_decoder"), + )?; + let pixel_mean = + Tensor::new(&[123.675f32, 116.28, 103.53], vb.device())?.reshape((3, 1, 1))?; + let pixel_std = + Tensor::new(&[58.395f32, 57.12, 57.375], vb.device())?.reshape((3, 1, 1))?; + Ok(Self { + image_encoder, + prompt_encoder, + mask_decoder, + pixel_std, + pixel_mean, + }) + } +} |