diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-09-07 13:06:55 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-09-07 12:06:55 +0100 |
commit | 7b50f3e106b3d3a333e1c67f2006cbfd60c8a55a (patch) | |
tree | 1015016bea7f9cef92486f75e1cc1fd3ebcd4df2 /candle-examples/examples/segment-anything/model_mask_decoder.rs | |
parent | 8c991df3945a7c86ae86a7a52a74639ec321cef2 (diff) | |
download | candle-7b50f3e106b3d3a333e1c67f2006cbfd60c8a55a.tar.gz candle-7b50f3e106b3d3a333e1c67f2006cbfd60c8a55a.tar.bz2 candle-7b50f3e106b3d3a333e1c67f2006cbfd60c8a55a.zip |
More segment-anything again. (#764)
* More segment-anything again.
* Transformer block forward.
* Two-ways transformer.
* Position embeddings.
* Sketch the prompt encoder.
* More prompt-encoder.
* More prompt-encoder.
* Add the main sam module.
* Embed the transformer.
* And hook the transformer forward step.
* Build the model.
* Handle the global attn indexes.
* Get the model to load.
Diffstat (limited to 'candle-examples/examples/segment-anything/model_mask_decoder.rs')
-rw-r--r-- | candle-examples/examples/segment-anything/model_mask_decoder.rs | 23 |
1 file changed, 15 insertions, 8 deletions
diff --git a/candle-examples/examples/segment-anything/model_mask_decoder.rs b/candle-examples/examples/segment-anything/model_mask_decoder.rs index 55a006c4..cf3879cd 100644 --- a/candle-examples/examples/segment-anything/model_mask_decoder.rs +++ b/candle-examples/examples/segment-anything/model_mask_decoder.rs @@ -1,6 +1,8 @@ use candle::{DType, IndexOp, Result, Tensor, D}; use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder}; +use crate::model_transformer::TwoWayTransformer; + #[derive(Debug)] struct MlpMaskDecoder { layers: Vec<Linear>, @@ -53,7 +55,7 @@ impl Module for MlpMaskDecoder { } #[derive(Debug)] -struct MaskDecoder { +pub struct MaskDecoder { iou_token: candle_nn::Embedding, mask_tokens: candle_nn::Embedding, iou_prediction_head: MlpMaskDecoder, @@ -62,17 +64,18 @@ struct MaskDecoder { output_upscaling_conv2: candle_nn::ConvTranspose2d, num_mask_tokens: usize, output_hypernetworks_mlps: Vec<MlpMaskDecoder>, + transformer: TwoWayTransformer, } impl MaskDecoder { - fn new( + pub fn new( transformer_dim: usize, num_multimask_outputs: usize, iou_head_depth: usize, iou_head_hidden_dim: usize, vb: VarBuilder, ) -> Result<Self> { - let num_mask_tokens = num_multimask_outputs - 1; + let num_mask_tokens = num_multimask_outputs + 1; let iou_prediction_head = MlpMaskDecoder::new( transformer_dim, iou_head_hidden_dim, @@ -117,6 +120,13 @@ impl MaskDecoder { )?; output_hypernetworks_mlps.push(mlp) } + let transformer = TwoWayTransformer::new( + /* depth */ 2, + /* embedding_dim */ transformer_dim, + /* num_heads */ 8, + /* mlp_dim */ 2048, + vb.pp("transformer"), + )?; Ok(Self { iou_token, mask_tokens, @@ -126,6 +136,7 @@ impl MaskDecoder { output_upscaling_conv2, num_mask_tokens, output_hypernetworks_mlps, + transformer, }) } @@ -182,7 +193,7 @@ impl MaskDecoder { let (b, c, h, w) = src.dims4()?; // Run the transformer - let (hs, src) = run_transformer(&src, &pos_src, &tokens)?; + let (hs, src) = self.transformer.forward(&src, &pos_src, &tokens)?; 
let iou_token_out = hs.i((.., 0))?; let mask_tokens_out = hs.i((.., 1, 1 + self.num_mask_tokens))?; @@ -216,7 +227,3 @@ impl MaskDecoder { fn repeat_interleave(_img: &Tensor, _repeats: usize, _dim: usize) -> Result<Tensor> { todo!() } - -fn run_transformer(_src: &Tensor, _pos: &Tensor, _tokens: &Tensor) -> Result<(Tensor, Tensor)> { - todo!() -} |