summaryrefslogtreecommitdiff
path: root/candle-examples/examples/segment-anything/model_mask_decoder.rs
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-09-07 13:06:55 +0200
committerGitHub <noreply@github.com>2023-09-07 12:06:55 +0100
commit7b50f3e106b3d3a333e1c67f2006cbfd60c8a55a (patch)
tree1015016bea7f9cef92486f75e1cc1fd3ebcd4df2 /candle-examples/examples/segment-anything/model_mask_decoder.rs
parent8c991df3945a7c86ae86a7a52a74639ec321cef2 (diff)
downloadcandle-7b50f3e106b3d3a333e1c67f2006cbfd60c8a55a.tar.gz
candle-7b50f3e106b3d3a333e1c67f2006cbfd60c8a55a.tar.bz2
candle-7b50f3e106b3d3a333e1c67f2006cbfd60c8a55a.zip
More segment-anything again. (#764)
* More segment-anything again. * Transformer block forward. * Two-ways transformer. * Position embeddings. * Sketch the prompt encoder. * More prompt-encoder. * More prompt-encoder. * Add the main sam module. * Embed the transformer. * And hook the transformer forward step. * Build the model. * Handle the global attn indexes. * Get the model to load.
Diffstat (limited to 'candle-examples/examples/segment-anything/model_mask_decoder.rs')
-rw-r--r--candle-examples/examples/segment-anything/model_mask_decoder.rs23
1 file changed, 15 insertions, 8 deletions
diff --git a/candle-examples/examples/segment-anything/model_mask_decoder.rs b/candle-examples/examples/segment-anything/model_mask_decoder.rs
index 55a006c4..cf3879cd 100644
--- a/candle-examples/examples/segment-anything/model_mask_decoder.rs
+++ b/candle-examples/examples/segment-anything/model_mask_decoder.rs
@@ -1,6 +1,8 @@
use candle::{DType, IndexOp, Result, Tensor, D};
use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};
+use crate::model_transformer::TwoWayTransformer;
+
#[derive(Debug)]
struct MlpMaskDecoder {
layers: Vec<Linear>,
@@ -53,7 +55,7 @@ impl Module for MlpMaskDecoder {
}
#[derive(Debug)]
-struct MaskDecoder {
+pub struct MaskDecoder {
iou_token: candle_nn::Embedding,
mask_tokens: candle_nn::Embedding,
iou_prediction_head: MlpMaskDecoder,
@@ -62,17 +64,18 @@ struct MaskDecoder {
output_upscaling_conv2: candle_nn::ConvTranspose2d,
num_mask_tokens: usize,
output_hypernetworks_mlps: Vec<MlpMaskDecoder>,
+ transformer: TwoWayTransformer,
}
impl MaskDecoder {
- fn new(
+ pub fn new(
transformer_dim: usize,
num_multimask_outputs: usize,
iou_head_depth: usize,
iou_head_hidden_dim: usize,
vb: VarBuilder,
) -> Result<Self> {
- let num_mask_tokens = num_multimask_outputs - 1;
+ let num_mask_tokens = num_multimask_outputs + 1;
let iou_prediction_head = MlpMaskDecoder::new(
transformer_dim,
iou_head_hidden_dim,
@@ -117,6 +120,13 @@ impl MaskDecoder {
)?;
output_hypernetworks_mlps.push(mlp)
}
+ let transformer = TwoWayTransformer::new(
+ /* depth */ 2,
+ /* embedding_dim */ transformer_dim,
+ /* num_heads */ 8,
+ /* mlp_dim */ 2048,
+ vb.pp("transformer"),
+ )?;
Ok(Self {
iou_token,
mask_tokens,
@@ -126,6 +136,7 @@ impl MaskDecoder {
output_upscaling_conv2,
num_mask_tokens,
output_hypernetworks_mlps,
+ transformer,
})
}
@@ -182,7 +193,7 @@ impl MaskDecoder {
let (b, c, h, w) = src.dims4()?;
// Run the transformer
- let (hs, src) = run_transformer(&src, &pos_src, &tokens)?;
+ let (hs, src) = self.transformer.forward(&src, &pos_src, &tokens)?;
let iou_token_out = hs.i((.., 0))?;
let mask_tokens_out = hs.i((.., 1, 1 + self.num_mask_tokens))?;
@@ -216,7 +227,3 @@ impl MaskDecoder {
fn repeat_interleave(_img: &Tensor, _repeats: usize, _dim: usize) -> Result<Tensor> {
todo!()
}
-
-fn run_transformer(_src: &Tensor, _pos: &Tensor, _tokens: &Tensor) -> Result<(Tensor, Tensor)> {
- todo!()
-}