summaryrefslogtreecommitdiff
path: root/candle-examples/examples/segment-anything
diff options
context:
space:
mode:
Diffstat (limited to 'candle-examples/examples/segment-anything')
-rw-r--r--candle-examples/examples/segment-anything/main.rs19
-rw-r--r--candle-examples/examples/segment-anything/model_image_encoder.rs25
-rw-r--r--candle-examples/examples/segment-anything/model_mask_decoder.rs23
-rw-r--r--candle-examples/examples/segment-anything/model_prompt_encoder.rs192
-rw-r--r--candle-examples/examples/segment-anything/model_sam.rs72
-rw-r--r--candle-examples/examples/segment-anything/model_transformer.rs143
6 files changed, 454 insertions, 20 deletions
diff --git a/candle-examples/examples/segment-anything/main.rs b/candle-examples/examples/segment-anything/main.rs
index de16f70c..368b5a33 100644
--- a/candle-examples/examples/segment-anything/main.rs
+++ b/candle-examples/examples/segment-anything/main.rs
@@ -8,9 +8,11 @@ extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
-mod model_image_encoder;
-mod model_mask_decoder;
-mod model_transformer;
+pub mod model_image_encoder;
+pub mod model_mask_decoder;
+pub mod model_prompt_encoder;
+pub mod model_sam;
+pub mod model_transformer;
use candle::{DType, IndexOp, Result, Tensor, D};
use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};
@@ -82,7 +84,7 @@ impl Module for MlpBlock {
#[derive(Parser)]
struct Args {
#[arg(long)]
- model: Option<String>,
+ model: String,
#[arg(long)]
image: String,
@@ -95,10 +97,15 @@ struct Args {
pub fn main() -> anyhow::Result<()> {
let args = Args::parse();
- let _device = candle_examples::device(args.cpu)?;
+ let device = candle_examples::device(args.cpu)?;
- let image = candle_examples::imagenet::load_image224(args.image)?;
+ let image = candle_examples::imagenet::load_image224(args.image)?.to_device(&device);
println!("loaded image {image:?}");
+ let weights = unsafe { candle::safetensors::MmapedFile::new(args.model)? };
+ let weights = weights.deserialize()?;
+ let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
+ let _sam = model_sam::Sam::new(768, 12, 12, &[2, 5, 8, 11], vb)?; // sam_vit_b
+
Ok(())
}
diff --git a/candle-examples/examples/segment-anything/model_image_encoder.rs b/candle-examples/examples/segment-anything/model_image_encoder.rs
index c8b6fd7b..cfcdbb38 100644
--- a/candle-examples/examples/segment-anything/model_image_encoder.rs
+++ b/candle-examples/examples/segment-anything/model_image_encoder.rs
@@ -47,7 +47,7 @@ impl Attention {
num_heads: usize,
qkv_bias: bool,
use_rel_pos: bool,
- window_size: usize,
+ input_size: (usize, usize),
vb: VarBuilder,
) -> Result<Self> {
let qkv = crate::linear(vb.pp("qkv"), dim, dim * 3, qkv_bias)?;
@@ -55,8 +55,8 @@ impl Attention {
let head_dim = dim / num_heads;
let scale = 1. / (head_dim as f64).sqrt();
let rel_pos_hw = if use_rel_pos {
- let h = vb.get((2 * window_size - 1, head_dim), "rel_pos_h")?;
- let w = vb.get((2 * window_size - 1, head_dim), "rel_pos_w")?;
+ let h = vb.get((2 * input_size.0 - 1, head_dim), "rel_pos_h")?;
+ let w = vb.get((2 * input_size.1 - 1, head_dim), "rel_pos_w")?;
Some((h, w))
} else {
None
@@ -114,16 +114,22 @@ impl Block {
qkv_bias: bool,
use_rel_pos: bool,
window_size: usize,
+ input_size: (usize, usize),
vb: VarBuilder,
) -> Result<Self> {
let norm1 = layer_norm(dim, 1e-5, vb.pp("norm1"))?;
let norm2 = layer_norm(dim, 1e-5, vb.pp("norm2"))?;
+ let input_size_attn = if window_size == 0 {
+ input_size
+ } else {
+ (window_size, window_size)
+ };
let attn = Attention::new(
dim,
num_heads,
qkv_bias,
use_rel_pos,
- window_size,
+ input_size_attn,
vb.pp("attn"),
)?;
let mlp = crate::MlpBlock::new(dim, dim * 4, vb.pp("mlp"))?;
@@ -154,7 +160,7 @@ impl Module for Block {
}
#[derive(Debug)]
-struct ImageEncoderViT {
+pub struct ImageEncoderViT {
img_size: usize,
patch_embed: PatchEmbed,
blocks: Vec<Block>,
@@ -167,7 +173,7 @@ struct ImageEncoderViT {
impl ImageEncoderViT {
#[allow(clippy::too_many_arguments)]
- fn new(
+ pub fn new(
img_size: usize,
patch_size: usize,
in_chans: usize,
@@ -179,6 +185,7 @@ impl ImageEncoderViT {
use_rel_pos: bool,
use_abs_pos: bool,
window_size: usize,
+ global_attn_indexes: &[usize],
vb: VarBuilder,
) -> Result<Self> {
let patch_embed = PatchEmbed::new(
@@ -192,12 +199,18 @@ impl ImageEncoderViT {
let mut blocks = Vec::with_capacity(depth);
let vb_b = vb.pp("blocks");
for i in 0..depth {
+ let window_size = if global_attn_indexes.contains(&i) {
+ 0
+ } else {
+ window_size
+ };
let block = Block::new(
embed_dim,
num_heads,
qkv_bias,
use_rel_pos,
window_size,
+ (img_size / patch_size, img_size / patch_size),
vb_b.pp(i),
)?;
blocks.push(block)
diff --git a/candle-examples/examples/segment-anything/model_mask_decoder.rs b/candle-examples/examples/segment-anything/model_mask_decoder.rs
index 55a006c4..cf3879cd 100644
--- a/candle-examples/examples/segment-anything/model_mask_decoder.rs
+++ b/candle-examples/examples/segment-anything/model_mask_decoder.rs
@@ -1,6 +1,8 @@
use candle::{DType, IndexOp, Result, Tensor, D};
use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};
+use crate::model_transformer::TwoWayTransformer;
+
#[derive(Debug)]
struct MlpMaskDecoder {
layers: Vec<Linear>,
@@ -53,7 +55,7 @@ impl Module for MlpMaskDecoder {
}
#[derive(Debug)]
-struct MaskDecoder {
+pub struct MaskDecoder {
iou_token: candle_nn::Embedding,
mask_tokens: candle_nn::Embedding,
iou_prediction_head: MlpMaskDecoder,
@@ -62,17 +64,18 @@ struct MaskDecoder {
output_upscaling_conv2: candle_nn::ConvTranspose2d,
num_mask_tokens: usize,
output_hypernetworks_mlps: Vec<MlpMaskDecoder>,
+ transformer: TwoWayTransformer,
}
impl MaskDecoder {
- fn new(
+ pub fn new(
transformer_dim: usize,
num_multimask_outputs: usize,
iou_head_depth: usize,
iou_head_hidden_dim: usize,
vb: VarBuilder,
) -> Result<Self> {
- let num_mask_tokens = num_multimask_outputs - 1;
+ let num_mask_tokens = num_multimask_outputs + 1;
let iou_prediction_head = MlpMaskDecoder::new(
transformer_dim,
iou_head_hidden_dim,
@@ -117,6 +120,13 @@ impl MaskDecoder {
)?;
output_hypernetworks_mlps.push(mlp)
}
+ let transformer = TwoWayTransformer::new(
+ /* depth */ 2,
+ /* embedding_dim */ transformer_dim,
+ /* num_heads */ 8,
+ /* mlp_dim */ 2048,
+ vb.pp("transformer"),
+ )?;
Ok(Self {
iou_token,
mask_tokens,
@@ -126,6 +136,7 @@ impl MaskDecoder {
output_upscaling_conv2,
num_mask_tokens,
output_hypernetworks_mlps,
+ transformer,
})
}
@@ -182,7 +193,7 @@ impl MaskDecoder {
let (b, c, h, w) = src.dims4()?;
// Run the transformer
- let (hs, src) = run_transformer(&src, &pos_src, &tokens)?;
+ let (hs, src) = self.transformer.forward(&src, &pos_src, &tokens)?;
let iou_token_out = hs.i((.., 0))?;
let mask_tokens_out = hs.i((.., 1, 1 + self.num_mask_tokens))?;
@@ -216,7 +227,3 @@ impl MaskDecoder {
fn repeat_interleave(_img: &Tensor, _repeats: usize, _dim: usize) -> Result<Tensor> {
todo!()
}
-
-fn run_transformer(_src: &Tensor, _pos: &Tensor, _tokens: &Tensor) -> Result<(Tensor, Tensor)> {
- todo!()
-}
diff --git a/candle-examples/examples/segment-anything/model_prompt_encoder.rs b/candle-examples/examples/segment-anything/model_prompt_encoder.rs
new file mode 100644
index 00000000..7ac4c66d
--- /dev/null
+++ b/candle-examples/examples/segment-anything/model_prompt_encoder.rs
@@ -0,0 +1,192 @@
+use candle::{DType, IndexOp, Result, Tensor, D};
+use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};
+
+#[derive(Debug)]
+struct PostionEmbeddingRandom {
+ positional_encoding_gaussian_matrix: Tensor,
+}
+
+impl PostionEmbeddingRandom {
+ fn new(num_pos_feats: usize, vb: VarBuilder) -> Result<Self> {
+ let positional_encoding_gaussian_matrix =
+ vb.get((2, num_pos_feats), "positional_encoding_gaussian_matrix")?;
+ Ok(Self {
+ positional_encoding_gaussian_matrix,
+ })
+ }
+
+ fn pe_encoding(&self, coords: &Tensor) -> Result<Tensor> {
+ let coords = coords.affine(2., -1.)?;
+ let coords = coords.matmul(&self.positional_encoding_gaussian_matrix)?;
+ let coords = (coords * (2. * std::f64::consts::PI))?;
+ Tensor::cat(&[coords.sin()?, coords.cos()?], D::Minus1)
+ }
+
+ fn forward(&self, h: usize, w: usize) -> Result<Tensor> {
+ let device = self.positional_encoding_gaussian_matrix.device();
+ let grid = Tensor::ones((h, w), DType::F32, device)?;
+ // TODO: cumsum
+ let x_embed = (&grid - 0.5)?;
+ // TODO: cumsum
+ let y_embed = (&grid - 0.5)?;
+ let x_embed = (x_embed / w as f64)?;
+ let y_embed = (y_embed / h as f64)?;
+ let coords = Tensor::stack(&[&x_embed, &y_embed], D::Minus1)?;
+ self.pe_encoding(&coords)?.permute((2, 0, 1))
+ }
+
+ fn forward_with_coords(
+ &self,
+ coords_input: &Tensor,
+ image_size: (usize, usize),
+ ) -> Result<Tensor> {
+ let coords0 = (coords_input.narrow(D::Minus1, 0, 1)? / image_size.1 as f64)?;
+ let coords1 = (coords_input.narrow(D::Minus1, 1, 1)? / image_size.0 as f64)?;
+ let c = coords_input.dim(D::Minus1)?;
+ let coords_rest = coords_input.narrow(D::Minus1, 2, c - 2)?;
+ let coords = Tensor::cat(&[&coords0, &coords1, &coords_rest], D::Minus1)?;
+ self.pe_encoding(&coords)
+ }
+}
+
+#[derive(Debug)]
+pub struct PromptEncoder {
+ pe_layer: PostionEmbeddingRandom,
+ point_embeddings: Vec<candle_nn::Embedding>,
+ not_a_point_embed: candle_nn::Embedding,
+ mask_downscaling_conv1: candle_nn::Conv2d,
+ mask_downscaling_ln1: LayerNorm,
+ mask_downscaling_conv2: candle_nn::Conv2d,
+ mask_downscaling_ln2: LayerNorm,
+ mask_downscaling_conv3: candle_nn::Conv2d,
+ no_mask_embed: candle_nn::Embedding,
+ image_embedding_size: (usize, usize),
+ input_image_size: (usize, usize),
+}
+
+impl PromptEncoder {
+ pub fn new(
+ embed_dim: usize,
+ image_embedding_size: (usize, usize),
+ input_image_size: (usize, usize),
+ mask_in_chans: usize,
+ vb: VarBuilder,
+ ) -> Result<Self> {
+ let num_points_embeddings = 4;
+ let pe_layer = PostionEmbeddingRandom::new(embed_dim / 2, vb.pp("pe_layer"))?;
+ let not_a_point_embed = candle_nn::embedding(1, embed_dim, vb.pp("not_a_point_embed"))?;
+ let no_mask_embed = candle_nn::embedding(1, embed_dim, vb.pp("no_mask_embed"))?;
+ let cfg = candle_nn::Conv2dConfig {
+ stride: 2,
+ ..Default::default()
+ };
+ let mask_downscaling_conv1 =
+ candle_nn::conv2d(1, mask_in_chans / 4, 2, cfg, vb.pp("mask_downscaling.0"))?;
+ let mask_downscaling_conv2 = candle_nn::conv2d(
+ mask_in_chans / 4,
+ mask_in_chans,
+ 2,
+ cfg,
+ vb.pp("mask_downscaling.3"),
+ )?;
+ let mask_downscaling_conv3 = candle_nn::conv2d(
+ mask_in_chans,
+ embed_dim,
+ 1,
+ Default::default(),
+ vb.pp("mask_downscaling.6"),
+ )?;
+ let mask_downscaling_ln1 =
+ layer_norm(mask_in_chans / 4, 1e-6, vb.pp("mask_downscaling.1"))?;
+ let mask_downscaling_ln2 = layer_norm(mask_in_chans, 1e-6, vb.pp("mask_downscaling.4"))?;
+ let mut point_embeddings = Vec::with_capacity(num_points_embeddings);
+ let vb_e = vb.pp("point_embeddings");
+ for i in 0..num_points_embeddings {
+ let emb = candle_nn::embedding(1, embed_dim, vb_e.pp(i))?;
+ point_embeddings.push(emb)
+ }
+ Ok(Self {
+ pe_layer,
+ point_embeddings,
+ not_a_point_embed,
+ mask_downscaling_conv1,
+ mask_downscaling_ln1,
+ mask_downscaling_conv2,
+ mask_downscaling_ln2,
+ mask_downscaling_conv3,
+ no_mask_embed,
+ image_embedding_size,
+ input_image_size,
+ })
+ }
+
+ fn embed_masks(&self, masks: &Tensor) -> Result<Tensor> {
+ masks
+ .apply(&self.mask_downscaling_conv1)?
+ .apply(&self.mask_downscaling_ln1)?
+ .gelu()?
+ .apply(&self.mask_downscaling_conv2)?
+ .apply(&self.mask_downscaling_ln2)?
+ .gelu()?
+ .apply(&self.mask_downscaling_conv3)
+ }
+
+ fn embed_points(&self, points: &Tensor, labels: &Tensor, pad: bool) -> Result<Tensor> {
+ let points = (points + 0.5)?;
+ let points = if pad { todo!() } else { points };
+ let point_embedding = self
+ .pe_layer
+ .forward_with_coords(&points, self.input_image_size)?;
+ // TODO: tweak based on labels.
+ Ok(point_embedding)
+ }
+
+ fn embed_boxes(&self, boxes: &Tensor) -> Result<Tensor> {
+ let boxes = (boxes + 0.5)?;
+ let coords = boxes.reshape((boxes.elem_count() / 4, 2, 2))?;
+ let corner_embedding = self
+ .pe_layer
+ .forward_with_coords(&coords, self.input_image_size)?;
+ let ce1 = corner_embedding.i((.., 0))?;
+ let ce2 = corner_embedding.i((.., 1))?;
+ let ce1 = (ce1 + self.point_embeddings[2].embeddings())?;
+ let ce2 = (ce2 + self.point_embeddings[3].embeddings())?;
+ Tensor::cat(&[&ce1, &ce2], 1)
+ }
+
+ fn forward(
+ &self,
+ points: Option<(&Tensor, &Tensor)>,
+ boxes: Option<&Tensor>,
+ masks: Option<&Tensor>,
+ ) -> Result<(Tensor, Tensor)> {
+ let se_points = match points {
+ Some((coords, labels)) => Some(self.embed_points(coords, labels, boxes.is_none())?),
+ None => None,
+ };
+ let se_boxes = match boxes {
+ Some(boxes) => Some(self.embed_boxes(boxes)?),
+ None => None,
+ };
+ let sparse_embeddings = match (se_points, se_boxes) {
+ (Some(se_points), Some(se_boxes)) => Tensor::cat(&[se_points, se_boxes], 1)?,
+ (Some(se_points), None) => se_points,
+ (None, Some(se_boxes)) => se_boxes,
+ (None, None) => Tensor::zeros(1, DType::F32, &candle::Device::Cpu)?,
+ };
+
+ let dense_embeddings = match masks {
+ None => {
+ let emb = self.no_mask_embed.embeddings();
+ emb.reshape((1, emb.elem_count(), 1, 1))?.expand((
+ 1,
+ 0,
+ self.image_embedding_size.0,
+ self.image_embedding_size.1,
+ ))?
+ }
+ Some(masks) => self.embed_masks(masks)?,
+ };
+ Ok((sparse_embeddings, dense_embeddings))
+ }
+}
diff --git a/candle-examples/examples/segment-anything/model_sam.rs b/candle-examples/examples/segment-anything/model_sam.rs
new file mode 100644
index 00000000..5a0d7e8f
--- /dev/null
+++ b/candle-examples/examples/segment-anything/model_sam.rs
@@ -0,0 +1,72 @@
+use candle::{DType, IndexOp, Result, Tensor, D};
+use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};
+
+use crate::model_image_encoder::ImageEncoderViT;
+use crate::model_mask_decoder::MaskDecoder;
+use crate::model_prompt_encoder::PromptEncoder;
+
+#[derive(Debug)]
+pub struct Sam {
+ image_encoder: ImageEncoderViT,
+ prompt_encoder: PromptEncoder,
+ mask_decoder: MaskDecoder,
+ pixel_mean: Tensor,
+ pixel_std: Tensor,
+}
+
+impl Sam {
+ pub fn new(
+ encoder_embed_dim: usize,
+ encoder_depth: usize,
+ encoder_num_heads: usize,
+ encoder_global_attn_indexes: &[usize],
+ vb: VarBuilder,
+ ) -> Result<Self> {
+ const PROMPT_EMBED_DIM: usize = 256;
+ const IMAGE_SIZE: usize = 1024;
+ const VIT_PATCH_SIZE: usize = 16;
+
+ let image_embedding_size = IMAGE_SIZE / VIT_PATCH_SIZE;
+
+ let image_encoder = ImageEncoderViT::new(
+ IMAGE_SIZE,
+ VIT_PATCH_SIZE,
+ 3,
+ encoder_embed_dim,
+ encoder_depth,
+ encoder_num_heads,
+ PROMPT_EMBED_DIM,
+ /* qkv_bias */ true,
+ /* use_rel_pos */ true,
+ /* use_abs_pos */ true,
+ /* window_size */ 14,
+ /* global_attn_indexes */ encoder_global_attn_indexes,
+ vb.pp("image_encoder"),
+ )?;
+ let prompt_encoder = PromptEncoder::new(
+ PROMPT_EMBED_DIM,
+ (image_embedding_size, image_embedding_size),
+ (IMAGE_SIZE, IMAGE_SIZE),
+ 16,
+ vb.pp("prompt_encoder"),
+ )?;
+ let mask_decoder = MaskDecoder::new(
+ PROMPT_EMBED_DIM,
+ /* num_multitask_outputs */ 3,
+ /* iou_head_depth */ 3,
+ /* iou_head_hidden_dim */ 256,
+ vb.pp("mask_decoder"),
+ )?;
+ let pixel_mean =
+ Tensor::new(&[123.675f32, 116.28, 103.53], vb.device())?.reshape((3, 1, 1))?;
+ let pixel_std =
+ Tensor::new(&[58.395f32, 57.12, 57.375], vb.device())?.reshape((3, 1, 1))?;
+ Ok(Self {
+ image_encoder,
+ prompt_encoder,
+ mask_decoder,
+ pixel_std,
+ pixel_mean,
+ })
+ }
+}
diff --git a/candle-examples/examples/segment-anything/model_transformer.rs b/candle-examples/examples/segment-anything/model_transformer.rs
index 10f7f4e5..a845085d 100644
--- a/candle-examples/examples/segment-anything/model_transformer.rs
+++ b/candle-examples/examples/segment-anything/model_transformer.rs
@@ -75,3 +75,146 @@ struct TwoWayAttentionBlock {
cross_attn_image_to_token: Attention,
skip_first_layer_pe: bool,
}
+
+impl TwoWayAttentionBlock {
+ fn new(
+ embedding_dim: usize,
+ num_heads: usize,
+ mlp_dim: usize,
+ skip_first_layer_pe: bool,
+ vb: VarBuilder,
+ ) -> Result<Self> {
+ let self_attn = Attention::new(embedding_dim, num_heads, 1, vb.pp("self_attn"))?;
+ let norm1 = layer_norm(embedding_dim, 1e-5, vb.pp("norm1"))?;
+ let norm2 = layer_norm(embedding_dim, 1e-5, vb.pp("norm2"))?;
+ let norm3 = layer_norm(embedding_dim, 1e-5, vb.pp("norm3"))?;
+ let norm4 = layer_norm(embedding_dim, 1e-5, vb.pp("norm4"))?;
+ let self_attn = Attention::new(embedding_dim, num_heads, 1, vb.pp("self_attn"))?;
+ let cross_attn_token_to_image = Attention::new(
+ embedding_dim,
+ num_heads,
+ 2,
+ vb.pp("cross_attn_token_to_image"),
+ )?;
+ let cross_attn_image_to_token = Attention::new(
+ embedding_dim,
+ num_heads,
+ 2,
+ vb.pp("cross_attn_image_to_token"),
+ )?;
+ // TODO: use relu in this mlp
+ let mlp = crate::MlpBlock::new(embedding_dim, mlp_dim, vb.pp("mlp"))?;
+ Ok(Self {
+ self_attn,
+ norm1,
+ cross_attn_image_to_token,
+ norm2,
+ mlp,
+ norm3,
+ norm4,
+ cross_attn_token_to_image,
+ skip_first_layer_pe,
+ })
+ }
+
+ fn forward(
+ &self,
+ queries: &Tensor,
+ keys: &Tensor,
+ query_pe: &Tensor,
+ key_pe: &Tensor,
+ ) -> Result<(Tensor, Tensor)> {
+ // Self attention block
+ let queries = if self.skip_first_layer_pe {
+ self.self_attn.forward(queries, keys, queries)?
+ } else {
+ let q = (queries + query_pe)?;
+ let attn_out = self.self_attn.forward(&q, &q, queries)?;
+ (queries + attn_out)?
+ };
+ let queries = self.norm1.forward(&queries)?;
+
+ // Cross attention block, tokens attending to image embedding
+ let q = (&queries + query_pe)?;
+ let k = (keys + key_pe)?;
+ let attn_out = self.cross_attn_token_to_image.forward(&q, &k, keys)?;
+ let queries = (&queries + attn_out)?;
+ let queries = self.norm2.forward(&queries)?;
+
+ // MLP block
+ let mlp_out = self.mlp.forward(&queries);
+ let queries = (queries + mlp_out)?;
+ let queries = self.norm3.forward(&queries)?;
+
+ // Cross attention block, image embedding attending to tokens
+ let q = (&queries + query_pe)?;
+ let k = (keys + key_pe)?;
+ let attn_out = self.cross_attn_image_to_token.forward(&k, &q, &queries)?;
+ let keys = (keys + attn_out)?;
+ let keys = self.norm4.forward(&keys)?;
+
+ Ok((queries, keys))
+ }
+}
+
+#[derive(Debug)]
+pub struct TwoWayTransformer {
+ layers: Vec<TwoWayAttentionBlock>,
+ final_attn_token_to_image: Attention,
+ norm_final_attn: LayerNorm,
+}
+
+impl TwoWayTransformer {
+ pub fn new(
+ depth: usize,
+ embedding_dim: usize,
+ num_heads: usize,
+ mlp_dim: usize,
+ vb: VarBuilder,
+ ) -> Result<Self> {
+ let vb_l = vb.pp("layers");
+ let mut layers = Vec::with_capacity(depth);
+ for i in 0..depth {
+ let layer =
+ TwoWayAttentionBlock::new(embedding_dim, num_heads, mlp_dim, i == 0, vb_l.pp(i))?;
+ layers.push(layer)
+ }
+ let final_attn_token_to_image = Attention::new(
+ embedding_dim,
+ num_heads,
+ 2,
+ vb.pp("final_attn_token_to_image"),
+ )?;
+ let norm_final_attn = layer_norm(embedding_dim, 1e-5, vb.pp("norm_final_attn"))?;
+ Ok(Self {
+ layers,
+ final_attn_token_to_image,
+ norm_final_attn,
+ })
+ }
+
+ pub fn forward(
+ &self,
+ image_embedding: &Tensor,
+ image_pe: &Tensor,
+ point_embedding: &Tensor,
+ ) -> Result<(Tensor, Tensor)> {
+ let (bs, c, h, w) = image_embedding.dims4()?;
+ let image_embedding = image_embedding.flatten_from(2)?.permute((0, 2, 1))?;
+ let image_pe = image_pe.flatten_from(2)?.permute((0, 2, 1))?;
+
+ let mut queries = point_embedding.clone();
+ let mut keys = image_embedding;
+
+ for layer in self.layers.iter() {
+ (queries, keys) = layer.forward(&queries, &keys, point_embedding, &image_pe)?
+ }
+
+ let q = (&queries + point_embedding)?;
+ let k = (&keys + image_pe)?;
+ let attn_out = self.final_attn_token_to_image.forward(&q, &k, &keys)?;
+ let queries = (queries + attn_out)?.apply(&self.norm_final_attn)?;
+
+ Ok((queries, keys))
+ }
+}