diff options
Diffstat (limited to 'candle-examples/examples/segment-anything/model_image_encoder.rs')
-rw-r--r-- | candle-examples/examples/segment-anything/model_image_encoder.rs | 25 |
1 files changed, 19 insertions, 6 deletions
diff --git a/candle-examples/examples/segment-anything/model_image_encoder.rs b/candle-examples/examples/segment-anything/model_image_encoder.rs index c8b6fd7b..cfcdbb38 100644 --- a/candle-examples/examples/segment-anything/model_image_encoder.rs +++ b/candle-examples/examples/segment-anything/model_image_encoder.rs @@ -47,7 +47,7 @@ impl Attention { num_heads: usize, qkv_bias: bool, use_rel_pos: bool, - window_size: usize, + input_size: (usize, usize), vb: VarBuilder, ) -> Result<Self> { let qkv = crate::linear(vb.pp("qkv"), dim, dim * 3, qkv_bias)?; @@ -55,8 +55,8 @@ impl Attention { let head_dim = dim / num_heads; let scale = 1. / (head_dim as f64).sqrt(); let rel_pos_hw = if use_rel_pos { - let h = vb.get((2 * window_size - 1, head_dim), "rel_pos_h")?; - let w = vb.get((2 * window_size - 1, head_dim), "rel_pos_w")?; + let h = vb.get((2 * input_size.0 - 1, head_dim), "rel_pos_h")?; + let w = vb.get((2 * input_size.1 - 1, head_dim), "rel_pos_w")?; Some((h, w)) } else { None @@ -114,16 +114,22 @@ impl Block { qkv_bias: bool, use_rel_pos: bool, window_size: usize, + input_size: (usize, usize), vb: VarBuilder, ) -> Result<Self> { let norm1 = layer_norm(dim, 1e-5, vb.pp("norm1"))?; let norm2 = layer_norm(dim, 1e-5, vb.pp("norm2"))?; + let input_size_attn = if window_size == 0 { + input_size + } else { + (window_size, window_size) + }; let attn = Attention::new( dim, num_heads, qkv_bias, use_rel_pos, - window_size, + input_size_attn, vb.pp("attn"), )?; let mlp = crate::MlpBlock::new(dim, dim * 4, vb.pp("mlp"))?; @@ -154,7 +160,7 @@ impl Module for Block { } #[derive(Debug)] -struct ImageEncoderViT { +pub struct ImageEncoderViT { img_size: usize, patch_embed: PatchEmbed, blocks: Vec<Block>, @@ -167,7 +173,7 @@ struct ImageEncoderViT { impl ImageEncoderViT { #[allow(clippy::too_many_arguments)] - fn new( + pub fn new( img_size: usize, patch_size: usize, in_chans: usize, @@ -179,6 +185,7 @@ impl ImageEncoderViT { use_rel_pos: bool, use_abs_pos: bool, window_size: usize, + global_attn_indexes: &[usize], vb: VarBuilder, ) -> Result<Self> { let patch_embed = PatchEmbed::new( @@ -192,12 +199,18 @@ impl ImageEncoderViT { let mut blocks = Vec::with_capacity(depth); let vb_b = vb.pp("blocks"); for i in 0..depth { + let window_size = if global_attn_indexes.contains(&i) { + 0 + } else { + window_size + }; let block = Block::new( embed_dim, num_heads, qkv_bias, use_rel_pos, window_size, + (img_size / patch_size, img_size / patch_size), vb_b.pp(i), )?; blocks.push(block) |