Diffstat (limited to 'candle-examples/examples/segment-anything/model_transformer.rs')
-rw-r--r--  candle-examples/examples/segment-anything/model_transformer.rs  221
1 file changed, 0 insertions(+), 221 deletions(-)
diff --git a/candle-examples/examples/segment-anything/model_transformer.rs b/candle-examples/examples/segment-anything/model_transformer.rs
deleted file mode 100644
index e12aac08..00000000
--- a/candle-examples/examples/segment-anything/model_transformer.rs
+++ /dev/null
@@ -1,221 +0,0 @@
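-// Two-way transformer decoder for the Segment Anything (SAM) example:
-// prompt tokens and the image embedding attend to each other in both
-// directions.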
-use candle::{Result, Tensor};
-use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};
-
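-/// Multi-head attention with separate q/k/v projections. `downsample_rate`
-/// shrinks the internal dimension, making cross attention cheaper.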
-#[derive(Debug)]
-struct Attention {
- q_proj: Linear,
- k_proj: Linear,
- v_proj: Linear,
- out_proj: Linear,
- num_heads: usize,
-}
-
-impl Attention {
- fn new(
- embedding_dim: usize,
- num_heads: usize,
- downsample_rate: usize,
- vb: VarBuilder,
- ) -> Result<Self> {
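- // With downsample_rate > 1 the projections map into a smaller
- // internal dimension.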
- let internal_dim = embedding_dim / downsample_rate;
- let q_proj = candle_nn::linear(embedding_dim, internal_dim, vb.pp("q_proj"))?;
- let k_proj = candle_nn::linear(embedding_dim, internal_dim, vb.pp("k_proj"))?;
- let v_proj = candle_nn::linear(embedding_dim, internal_dim, vb.pp("v_proj"))?;
- let out_proj = candle_nn::linear(internal_dim, embedding_dim, vb.pp("out_proj"))?;
- Ok(Self {
- q_proj,
- k_proj,
- v_proj,
- out_proj,
- num_heads,
- })
- }
-
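- /// (b, n, c) -> (b, num_heads, n, c / num_heads)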
- fn separate_heads(&self, x: &Tensor) -> Result<Tensor> {
- let (b, n, c) = x.dims3()?;
- x.reshape((b, n, self.num_heads, c / self.num_heads))?
- .transpose(1, 2)?
- .contiguous()
- }
-
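- /// Inverse of separate_heads: merges the heads back into (b, n_tokens, n_heads * c_per_head).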
- fn recombine_heads(&self, x: &Tensor) -> Result<Tensor> {
- let (b, n_heads, n_tokens, c_per_head) = x.dims4()?;
- x.transpose(1, 2)?
- .reshape((b, n_tokens, n_heads * c_per_head))
- }
-
- fn forward(&self, q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
- let q = self.q_proj.forward(&q.contiguous()?)?;
- let k = self.k_proj.forward(&k.contiguous()?)?;
- let v = self.v_proj.forward(&v.contiguous()?)?;
-
- let q = self.separate_heads(&q)?;
- let k = self.separate_heads(&k)?;
- let v = self.separate_heads(&v)?;
-
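- // Scaled dot-product attention: softmax(q k^T / sqrt(c_per_head)) v.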
- let (_, _, _, c_per_head) = q.dims4()?;
- let attn = (q.matmul(&k.t()?)? / (c_per_head as f64).sqrt())?;
- let attn = candle_nn::ops::softmax_last_dim(&attn)?;
-
- let out = attn.matmul(&v)?;
- self.recombine_heads(&out)?.apply(&self.out_proj)
- }
-}
-
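-/// One decoder block: self attention over the tokens, token-to-image cross
-/// attention, an MLP, and image-to-token cross attention, each followed by
-/// a layer norm and wired with residual connections.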
-#[derive(Debug)]
-struct TwoWayAttentionBlock {
- self_attn: Attention,
- norm1: LayerNorm,
- cross_attn_token_to_image: Attention,
- norm2: LayerNorm,
- mlp: crate::MlpBlock,
- norm3: LayerNorm,
- norm4: LayerNorm,
- cross_attn_image_to_token: Attention,
- skip_first_layer_pe: bool,
-}
-
-impl TwoWayAttentionBlock {
- fn new(
- embedding_dim: usize,
- num_heads: usize,
- mlp_dim: usize,
- skip_first_layer_pe: bool,
- vb: VarBuilder,
- ) -> Result<Self> {
- let norm1 = layer_norm(embedding_dim, 1e-5, vb.pp("norm1"))?;
- let norm2 = layer_norm(embedding_dim, 1e-5, vb.pp("norm2"))?;
- let norm3 = layer_norm(embedding_dim, 1e-5, vb.pp("norm3"))?;
- let norm4 = layer_norm(embedding_dim, 1e-5, vb.pp("norm4"))?;
- let self_attn = Attention::new(embedding_dim, num_heads, 1, vb.pp("self_attn"))?;
- let cross_attn_token_to_image = Attention::new(
- embedding_dim,
- num_heads,
- 2,
- vb.pp("cross_attn_token_to_image"),
- )?;
- let cross_attn_image_to_token = Attention::new(
- embedding_dim,
- num_heads,
- 2,
- vb.pp("cross_attn_image_to_token"),
- )?;
- let mlp = crate::MlpBlock::new(
- embedding_dim,
- mlp_dim,
- candle_nn::Activation::Relu,
- vb.pp("mlp"),
- )?;
- Ok(Self {
- self_attn,
- norm1,
- cross_attn_token_to_image,
- norm2,
- mlp,
- norm3,
- norm4,
- cross_attn_image_to_token,
- skip_first_layer_pe,
- })
- }
-
- fn forward(
- &self,
- queries: &Tensor,
- keys: &Tensor,
- query_pe: &Tensor,
- key_pe: &Tensor,
- ) -> Result<(Tensor, Tensor)> {
- // Self attention block
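- // When skip_first_layer_pe is set (first layer only), the positional
- // encoding is not added to the queries before self attention.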
- let queries = if self.skip_first_layer_pe {
- self.self_attn.forward(queries, queries, queries)?
- } else {
- let q = (queries + query_pe)?;
- let attn_out = self.self_attn.forward(&q, &q, queries)?;
- (queries + attn_out)?
- };
- let queries = self.norm1.forward(&queries)?;
-
- // Cross attention block, tokens attending to image embedding
- let q = (&queries + query_pe)?;
- let k = (keys + key_pe)?;
- let attn_out = self.cross_attn_token_to_image.forward(&q, &k, keys)?;
- let queries = (&queries + attn_out)?;
- let queries = self.norm2.forward(&queries)?;
-
- // MLP block
- let mlp_out = self.mlp.forward(&queries)?;
- let queries = (queries + mlp_out)?;
- let queries = self.norm3.forward(&queries)?;
-
- // Cross attention block, image embedding attending to tokens
- let q = (&queries + query_pe)?;
- let k = (keys + key_pe)?;
- let attn_out = self.cross_attn_image_to_token.forward(&k, &q, &queries)?;
- let keys = (keys + attn_out)?;
- let keys = self.norm4.forward(&keys)?;
-
- Ok((queries, keys))
- }
-}
-
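-/// A stack of TwoWayAttentionBlocks followed by a final token-to-image
-/// attention and layer norm. Returns the processed point (query) and image
-/// (key) embeddings.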
-#[derive(Debug)]
-pub struct TwoWayTransformer {
- layers: Vec<TwoWayAttentionBlock>,
- final_attn_token_to_image: Attention,
- norm_final_attn: LayerNorm,
-}
-
-impl TwoWayTransformer {
- pub fn new(
- depth: usize,
- embedding_dim: usize,
- num_heads: usize,
- mlp_dim: usize,
- vb: VarBuilder,
- ) -> Result<Self> {
- let vb_l = vb.pp("layers");
- let mut layers = Vec::with_capacity(depth);
- for i in 0..depth {
- let layer =
- TwoWayAttentionBlock::new(embedding_dim, num_heads, mlp_dim, i == 0, vb_l.pp(i))?;
- layers.push(layer)
- }
- let final_attn_token_to_image = Attention::new(
- embedding_dim,
- num_heads,
- 2,
- vb.pp("final_attn_token_to_image"),
- )?;
- let norm_final_attn = layer_norm(embedding_dim, 1e-5, vb.pp("norm_final_attn"))?;
- Ok(Self {
- layers,
- final_attn_token_to_image,
- norm_final_attn,
- })
- }
-
- pub fn forward(
- &self,
- image_embedding: &Tensor,
- image_pe: &Tensor,
- point_embedding: &Tensor,
- ) -> Result<(Tensor, Tensor)> {
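- // Flatten the spatial image embedding from (b, c, h, w) to (b, h*w, c).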
- let image_embedding = image_embedding.flatten_from(2)?.permute((0, 2, 1))?;
- let image_pe = image_pe.flatten_from(2)?.permute((0, 2, 1))?;
-
- let mut queries = point_embedding.clone();
- let mut keys = image_embedding;
-
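- // Run the two-way blocks, updating both the token and image streams.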
- for layer in self.layers.iter() {
- (queries, keys) = layer.forward(&queries, &keys, point_embedding, &image_pe)?
- }
-
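- // Final attention from the point tokens to the image, then a layer norm
- // on the queries.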
- let q = (&queries + point_embedding)?;
- let k = (&keys + image_pe)?;
- let attn_out = self.final_attn_token_to_image.forward(&q, &k, &keys)?;
- let queries = (queries + attn_out)?.apply(&self.norm_final_attn)?;
-
- Ok((queries, keys))
- }
-}
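-
-// Usage sketch (hypothetical call site; the hyperparameters below are SAM's
-// published defaults, not values shown in this diff):
-//
-// let t = TwoWayTransformer::new(2, 256, 8, 2048, vb.pp("transformer"))?;
-// let (queries, keys) = t.forward(&image_embedding, &image_pe, &point_embedding)?;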