diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-09-07 13:06:55 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-09-07 12:06:55 +0100 |
commit | 7b50f3e106b3d3a333e1c67f2006cbfd60c8a55a (patch) | |
tree | 1015016bea7f9cef92486f75e1cc1fd3ebcd4df2 /candle-examples/examples/segment-anything/model_sam.rs | |
parent | 8c991df3945a7c86ae86a7a52a74639ec321cef2 (diff) | |
download | candle-7b50f3e106b3d3a333e1c67f2006cbfd60c8a55a.tar.gz candle-7b50f3e106b3d3a333e1c67f2006cbfd60c8a55a.tar.bz2 candle-7b50f3e106b3d3a333e1c67f2006cbfd60c8a55a.zip |
More segment-anything again. (#764)
* More segment-anything again.
* Transformer block forward.
* Two-ways transformer.
* Position embeddings.
* Sketch the prompt encoder.
* More prompt-encoder.
* More prompt-encoder.
* Add the main sam module.
* Embed the transformer.
* And hook the transformer forward step.
* Build the model.
* Handle the global attn indexes.
* Get the model to load.
Diffstat (limited to 'candle-examples/examples/segment-anything/model_sam.rs')
-rw-r--r-- | candle-examples/examples/segment-anything/model_sam.rs | 72 |
1 files changed, 72 insertions, 0 deletions
diff --git a/candle-examples/examples/segment-anything/model_sam.rs b/candle-examples/examples/segment-anything/model_sam.rs new file mode 100644 index 00000000..5a0d7e8f --- /dev/null +++ b/candle-examples/examples/segment-anything/model_sam.rs @@ -0,0 +1,72 @@ +use candle::{DType, IndexOp, Result, Tensor, D}; +use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder}; + +use crate::model_image_encoder::ImageEncoderViT; +use crate::model_mask_decoder::MaskDecoder; +use crate::model_prompt_encoder::PromptEncoder; + +#[derive(Debug)] +pub struct Sam { + image_encoder: ImageEncoderViT, + prompt_encoder: PromptEncoder, + mask_decoder: MaskDecoder, + pixel_mean: Tensor, + pixel_std: Tensor, +} + +impl Sam { + pub fn new( + encoder_embed_dim: usize, + encoder_depth: usize, + encoder_num_heads: usize, + encoder_global_attn_indexes: &[usize], + vb: VarBuilder, + ) -> Result<Self> { + const PROMPT_EMBED_DIM: usize = 256; + const IMAGE_SIZE: usize = 1024; + const VIT_PATCH_SIZE: usize = 16; + + let image_embedding_size = IMAGE_SIZE / VIT_PATCH_SIZE; + + let image_encoder = ImageEncoderViT::new( + IMAGE_SIZE, + VIT_PATCH_SIZE, + 3, + encoder_embed_dim, + encoder_depth, + encoder_num_heads, + PROMPT_EMBED_DIM, + /* qkv_bias */ true, + /* use_rel_pos */ true, + /* use_abs_pos */ true, + /* window_size */ 14, + /* global_attn_indexes */ encoder_global_attn_indexes, + vb.pp("image_encoder"), + )?; + let prompt_encoder = PromptEncoder::new( + PROMPT_EMBED_DIM, + (image_embedding_size, image_embedding_size), + (IMAGE_SIZE, IMAGE_SIZE), + 16, + vb.pp("prompt_encoder"), + )?; + let mask_decoder = MaskDecoder::new( + PROMPT_EMBED_DIM, + /* num_multitask_outputs */ 3, + /* iou_head_depth */ 3, + /* iou_head_hidden_dim */ 256, + vb.pp("mask_decoder"), + )?; + let pixel_mean = + Tensor::new(&[123.675f32, 116.28, 103.53], vb.device())?.reshape((3, 1, 1))?; + let pixel_std = + Tensor::new(&[58.395f32, 57.12, 57.375], vb.device())?.reshape((3, 1, 1))?; + Ok(Self { + image_encoder, + prompt_encoder, + mask_decoder, + pixel_std, + pixel_mean, + }) + } +} |