summaryrefslogtreecommitdiff
path: root/candle-examples/examples/segment-anything/model_image_encoder.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-examples/examples/segment-anything/model_image_encoder.rs')
-rw-r--r--candle-examples/examples/segment-anything/model_image_encoder.rs25
1 files changed, 19 insertions, 6 deletions
diff --git a/candle-examples/examples/segment-anything/model_image_encoder.rs b/candle-examples/examples/segment-anything/model_image_encoder.rs
index c8b6fd7b..cfcdbb38 100644
--- a/candle-examples/examples/segment-anything/model_image_encoder.rs
+++ b/candle-examples/examples/segment-anything/model_image_encoder.rs
@@ -47,7 +47,7 @@ impl Attention {
num_heads: usize,
qkv_bias: bool,
use_rel_pos: bool,
- window_size: usize,
+ input_size: (usize, usize),
vb: VarBuilder,
) -> Result<Self> {
let qkv = crate::linear(vb.pp("qkv"), dim, dim * 3, qkv_bias)?;
@@ -55,8 +55,8 @@ impl Attention {
let head_dim = dim / num_heads;
let scale = 1. / (head_dim as f64).sqrt();
let rel_pos_hw = if use_rel_pos {
- let h = vb.get((2 * window_size - 1, head_dim), "rel_pos_h")?;
- let w = vb.get((2 * window_size - 1, head_dim), "rel_pos_w")?;
+ let h = vb.get((2 * input_size.0 - 1, head_dim), "rel_pos_h")?;
+ let w = vb.get((2 * input_size.1 - 1, head_dim), "rel_pos_w")?;
Some((h, w))
} else {
None
@@ -114,16 +114,22 @@ impl Block {
qkv_bias: bool,
use_rel_pos: bool,
window_size: usize,
+ input_size: (usize, usize),
vb: VarBuilder,
) -> Result<Self> {
let norm1 = layer_norm(dim, 1e-5, vb.pp("norm1"))?;
let norm2 = layer_norm(dim, 1e-5, vb.pp("norm2"))?;
+ let input_size_attn = if window_size == 0 {
+ input_size
+ } else {
+ (window_size, window_size)
+ };
let attn = Attention::new(
dim,
num_heads,
qkv_bias,
use_rel_pos,
- window_size,
+ input_size_attn,
vb.pp("attn"),
)?;
let mlp = crate::MlpBlock::new(dim, dim * 4, vb.pp("mlp"))?;
@@ -154,7 +160,7 @@ impl Module for Block {
}
#[derive(Debug)]
-struct ImageEncoderViT {
+pub struct ImageEncoderViT {
img_size: usize,
patch_embed: PatchEmbed,
blocks: Vec<Block>,
@@ -167,7 +173,7 @@ struct ImageEncoderViT {
impl ImageEncoderViT {
#[allow(clippy::too_many_arguments)]
- fn new(
+ pub fn new(
img_size: usize,
patch_size: usize,
in_chans: usize,
@@ -179,6 +185,7 @@ impl ImageEncoderViT {
use_rel_pos: bool,
use_abs_pos: bool,
window_size: usize,
+ global_attn_indexes: &[usize],
vb: VarBuilder,
) -> Result<Self> {
let patch_embed = PatchEmbed::new(
@@ -192,12 +199,18 @@ impl ImageEncoderViT {
let mut blocks = Vec::with_capacity(depth);
let vb_b = vb.pp("blocks");
for i in 0..depth {
+ let window_size = if global_attn_indexes.contains(&i) {
+ 0
+ } else {
+ window_size
+ };
let block = Block::new(
embed_dim,
num_heads,
qkv_bias,
use_rel_pos,
window_size,
+ (img_size / patch_size, img_size / patch_size),
vb_b.pp(i),
)?;
blocks.push(block)