Diffstat (limited to 'candle-examples/examples/segment-anything/model_transformer.rs')
-rw-r--r--  candle-examples/examples/segment-anything/model_transformer.rs  143
1 file changed, 143 insertions(+), 0 deletions(-)
diff --git a/candle-examples/examples/segment-anything/model_transformer.rs b/candle-examples/examples/segment-anything/model_transformer.rs
index 10f7f4e5..a845085d 100644
--- a/candle-examples/examples/segment-anything/model_transformer.rs
+++ b/candle-examples/examples/segment-anything/model_transformer.rs
@@ -75,3 +75,146 @@ struct TwoWayAttentionBlock {
cross_attn_image_to_token: Attention,
skip_first_layer_pe: bool,
}
+
+impl TwoWayAttentionBlock {
+ fn new(
+ embedding_dim: usize,
+ num_heads: usize,
+ mlp_dim: usize,
+ skip_first_layer_pe: bool,
+ vb: VarBuilder,
+ ) -> Result<Self> {
+ let self_attn = Attention::new(embedding_dim, num_heads, 1, vb.pp("self_attn"))?;
+ let norm1 = layer_norm(embedding_dim, 1e-5, vb.pp("norm1"))?;
+ let norm2 = layer_norm(embedding_dim, 1e-5, vb.pp("norm2"))?;
+ let norm3 = layer_norm(embedding_dim, 1e-5, vb.pp("norm3"))?;
+ let norm4 = layer_norm(embedding_dim, 1e-5, vb.pp("norm4"))?;
+ let cross_attn_token_to_image = Attention::new(
+ embedding_dim,
+ num_heads,
+ 2,
+ vb.pp("cross_attn_token_to_image"),
+ )?;
+ let cross_attn_image_to_token = Attention::new(
+ embedding_dim,
+ num_heads,
+ 2,
+ vb.pp("cross_attn_image_to_token"),
+ )?;
+ // TODO: use relu in this mlp, as in the reference SAM implementation
+ let mlp = crate::MlpBlock::new(embedding_dim, mlp_dim, vb.pp("mlp"))?;
+ Ok(Self {
+ self_attn,
+ norm1,
+ cross_attn_image_to_token,
+ norm2,
+ mlp,
+ norm3,
+ norm4,
+ cross_attn_token_to_image,
+ skip_first_layer_pe,
+ })
+ }
+
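+ /// One two-way block: (1) self-attention over the tokens, (2) cross-attention
+ /// from tokens to the image embedding, (3) an MLP over the tokens, and
+ /// (4) cross-attention from the image embedding back to the tokens, with
+ /// residual adds and layer norms in between. Returns the updated
+ /// (queries, keys) pair.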
+ fn forward(
+ &self,
+ queries: &Tensor,
+ keys: &Tensor,
+ query_pe: &Tensor,
+ key_pe: &Tensor,
+ ) -> Result<(Tensor, Tensor)> {
+ // Self attention block. On the first layer the queries are the point
+ // embeddings themselves, so the positional term (and the residual add)
+ // is skipped.
+ let queries = if self.skip_first_layer_pe {
+ self.self_attn.forward(queries, queries, queries)?
+ } else {
+ let q = (queries + query_pe)?;
+ let attn_out = self.self_attn.forward(&q, &q, queries)?;
+ (queries + attn_out)?
+ };
+ let queries = self.norm1.forward(&queries)?;
+
+ // Cross attention block, tokens attending to image embedding
+ let q = (&queries + query_pe)?;
+ let k = (keys + key_pe)?;
+ let attn_out = self.cross_attn_token_to_image.forward(&q, &k, keys)?;
+ let queries = (&queries + attn_out)?;
+ let queries = self.norm2.forward(&queries)?;
+
+ // MLP block
+ let mlp_out = self.mlp.forward(&queries)?;
+ let queries = (queries + mlp_out)?;
+ let queries = self.norm3.forward(&queries)?;
+
+ // Cross attention block, image embedding attending to tokens
+ let q = (&queries + query_pe)?;
+ let k = (keys + key_pe)?;
+ let attn_out = self.cross_attn_image_to_token.forward(&k, &q, &queries)?;
+ let keys = (keys + attn_out)?;
+ let keys = self.norm4.forward(&keys)?;
+
+ Ok((queries, keys))
+ }
+}
+
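+/// Two-way transformer of the SAM mask decoder: a stack of
+/// TwoWayAttentionBlocks followed by a final token-to-image attention layer
+/// and a layer norm.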
+#[derive(Debug)]
+pub struct TwoWayTransformer {
+ layers: Vec<TwoWayAttentionBlock>,
+ final_attn_token_to_image: Attention,
+ norm_final_attn: LayerNorm,
+}
+
+impl TwoWayTransformer {
+ pub fn new(
+ depth: usize,
+ embedding_dim: usize,
+ num_heads: usize,
+ mlp_dim: usize,
+ vb: VarBuilder,
+ ) -> Result<Self> {
+ let vb_l = vb.pp("layers");
+ let mut layers = Vec::with_capacity(depth);
+ for i in 0..depth {
+ let layer =
+ TwoWayAttentionBlock::new(embedding_dim, num_heads, mlp_dim, i == 0, vb_l.pp(i))?;
+ layers.push(layer)
+ }
+ let final_attn_token_to_image = Attention::new(
+ embedding_dim,
+ num_heads,
+ 2,
+ vb.pp("final_attn_token_to_image"),
+ )?;
+ let norm_final_attn = layer_norm(embedding_dim, 1e-5, vb.pp("norm_final_attn"))?;
+ Ok(Self {
+ layers,
+ final_attn_token_to_image,
+ norm_final_attn,
+ })
+ }
+
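+ /// image_embedding and image_pe come in as (B, C, H, W) and are flattened to
+ /// (B, H*W, C); point_embedding is (B, N_tokens, C) and seeds the queries.
+ /// Returns the final (queries, keys) pair.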
+ pub fn forward(
+ &self,
+ image_embedding: &Tensor,
+ image_pe: &Tensor,
+ point_embedding: &Tensor,
+ ) -> Result<(Tensor, Tensor)> {
+ // dims4 checks that the input is (B, C, H, W); the values themselves are unused.
+ let (_bs, _c, _h, _w) = image_embedding.dims4()?;
+ // Flatten the spatial dims: (B, C, H, W) -> (B, H*W, C).
+ let image_embedding = image_embedding.flatten_from(2)?.permute((0, 2, 1))?;
+ let image_pe = image_pe.flatten_from(2)?.permute((0, 2, 1))?;
+
+ let mut queries = point_embedding.clone();
+ let mut keys = image_embedding;
+
+ for layer in self.layers.iter() {
+ (queries, keys) = layer.forward(&queries, &keys, point_embedding, &image_pe)?
+ }
+
+ let q = (&queries + point_embedding)?;
+ let k = (&keys + image_pe)?;
+ let attn_out = self.final_attn_token_to_image.forward(&q, &k, &keys)?;
+ let queries = (queries + attn_out)?.apply(&self.norm_final_attn)?;
+
+ Ok((queries, keys))
+ }
+}
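
A minimal usage sketch, not part of the commit above: it wires the transformer
up with SAM's default mask-decoder hyper-parameters (depth 2, embedding dim 256,
8 heads, MLP dim 2048) and runs a forward pass on dummy inputs. The
zero-initialized VarBuilder and the tensor shapes are illustrative assumptions;
a real run would load pretrained weights instead.

    use candle::{DType, Device, Result, Tensor};
    use candle_nn::VarBuilder;

    fn sanity_check() -> Result<()> {
        let dev = Device::Cpu;
        // Zero weights are enough to exercise the shapes end to end.
        let vb = VarBuilder::zeros(DType::F32, &dev);
        let transformer = TwoWayTransformer::new(2, 256, 8, 2048, vb)?;
        // Image embedding and its positional encoding, (B, C, H, W).
        let image_embedding = Tensor::zeros((1, 256, 64, 64), DType::F32, &dev)?;
        let image_pe = Tensor::zeros((1, 256, 64, 64), DType::F32, &dev)?;
        // Prompt tokens, (B, N_tokens, C).
        let point_embedding = Tensor::zeros((1, 5, 256), DType::F32, &dev)?;
        let (queries, keys) =
            transformer.forward(&image_embedding, &image_pe, &point_embedding)?;
        // Expect queries (1, 5, 256) and keys (1, 4096, 256).
        println!("queries {:?} keys {:?}", queries.dims(), keys.dims());
        Ok(())
    }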