summaryrefslogtreecommitdiff
path: root/candle-examples/examples/segment-anything/model_mask_decoder.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-examples/examples/segment-anything/model_mask_decoder.rs')
-rw-r--r--candle-examples/examples/segment-anything/model_mask_decoder.rs222
1 files changed, 222 insertions, 0 deletions
diff --git a/candle-examples/examples/segment-anything/model_mask_decoder.rs b/candle-examples/examples/segment-anything/model_mask_decoder.rs
new file mode 100644
index 00000000..55a006c4
--- /dev/null
+++ b/candle-examples/examples/segment-anything/model_mask_decoder.rs
@@ -0,0 +1,222 @@
+use candle::{DType, IndexOp, Result, Tensor, D};
+use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};
+
/// A small multi-layer perceptron: a stack of linear layers with ReLU
/// activations between them and an optional sigmoid on the final output.
#[derive(Debug)]
struct MlpMaskDecoder {
    // Linear layers applied in sequence; ReLU is inserted after every layer
    // except the last (see the `Module` impl below).
    layers: Vec<Linear>,
    // When true, a sigmoid is applied to the final layer's output.
    sigmoid_output: bool,
}
+
+impl MlpMaskDecoder {
+ fn new(
+ input_dim: usize,
+ hidden_dim: usize,
+ output_dim: usize,
+ num_layers: usize,
+ sigmoid_output: bool,
+ vb: VarBuilder,
+ ) -> Result<Self> {
+ let mut layers = Vec::with_capacity(num_layers);
+ let vb = vb.pp("layers");
+ for i in 0..num_layers {
+ let in_dim = if i == 0 { input_dim } else { hidden_dim };
+ let out_dim = if i + 1 == num_layers {
+ output_dim
+ } else {
+ hidden_dim
+ };
+ let layer = crate::linear(vb.pp(i), in_dim, out_dim, true)?;
+ layers.push(layer)
+ }
+ Ok(Self {
+ layers,
+ sigmoid_output,
+ })
+ }
+}
+
+impl Module for MlpMaskDecoder {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let mut xs = xs.clone();
+ for (i, layer) in self.layers.iter().enumerate() {
+ xs = layer.forward(&xs)?;
+ if i + 1 < self.layers.len() {
+ xs = xs.relu()?
+ }
+ }
+ if self.sigmoid_output {
+ candle_nn::ops::sigmoid(&xs)
+ } else {
+ Ok(xs)
+ }
+ }
+}
+
/// SAM mask decoder: turns image embeddings plus prompt embeddings into mask
/// logits and per-mask IoU quality predictions.
#[derive(Debug)]
struct MaskDecoder {
    // Learned token prepended to the prompt tokens; its transformer output
    // feeds `iou_prediction_head`.
    iou_token: candle_nn::Embedding,
    // Learned tokens whose transformer outputs feed the hypernetwork MLPs.
    mask_tokens: candle_nn::Embedding,
    // MLP predicting one IoU score per mask token.
    iou_prediction_head: MlpMaskDecoder,
    // Upscaling path for the image embedding: conv-transpose, layer norm,
    // conv-transpose (with GELUs applied in `predict_masks`).
    output_upscaling_conv1: candle_nn::ConvTranspose2d,
    output_upscaling_ln: LayerNorm,
    output_upscaling_conv2: candle_nn::ConvTranspose2d,
    num_mask_tokens: usize,
    // One MLP per mask token producing per-mask weights over the upscaled
    // embedding channels.
    output_hypernetworks_mlps: Vec<MlpMaskDecoder>,
}
+
+impl MaskDecoder {
+ fn new(
+ transformer_dim: usize,
+ num_multimask_outputs: usize,
+ iou_head_depth: usize,
+ iou_head_hidden_dim: usize,
+ vb: VarBuilder,
+ ) -> Result<Self> {
+ let num_mask_tokens = num_multimask_outputs - 1;
+ let iou_prediction_head = MlpMaskDecoder::new(
+ transformer_dim,
+ iou_head_hidden_dim,
+ num_mask_tokens,
+ iou_head_depth,
+ false,
+ vb.pp("iou_prediction_head"),
+ )?;
+ let iou_token = candle_nn::embedding(1, transformer_dim, vb.pp("iou_token"))?;
+ let mask_tokens =
+ candle_nn::embedding(num_mask_tokens, transformer_dim, vb.pp("mask_tokens"))?;
+ let cfg = candle_nn::ConvTranspose2dConfig {
+ stride: 2,
+ ..Default::default()
+ };
+ let output_upscaling_conv1 = candle_nn::conv_transpose2d(
+ transformer_dim,
+ transformer_dim / 4,
+ 2,
+ cfg,
+ vb.pp("output_upscaling.0"),
+ )?;
+ let output_upscaling_ln =
+ layer_norm(transformer_dim / 4, 1e-6, vb.pp("output_upscaling.1"))?;
+ let output_upscaling_conv2 = candle_nn::conv_transpose2d(
+ transformer_dim / 4,
+ transformer_dim / 8,
+ 2,
+ cfg,
+ vb.pp("output_upscaling.3"),
+ )?;
+ let mut output_hypernetworks_mlps = Vec::with_capacity(num_mask_tokens);
+ let vb_o = vb.pp("output_hypernetworks_mlps");
+ for i in 0..num_mask_tokens {
+ let mlp = MlpMaskDecoder::new(
+ transformer_dim,
+ transformer_dim,
+ transformer_dim / 8,
+ 3,
+ false,
+ vb_o.pp(i),
+ )?;
+ output_hypernetworks_mlps.push(mlp)
+ }
+ Ok(Self {
+ iou_token,
+ mask_tokens,
+ iou_prediction_head,
+ output_upscaling_conv1,
+ output_upscaling_ln,
+ output_upscaling_conv2,
+ num_mask_tokens,
+ output_hypernetworks_mlps,
+ })
+ }
+
+ fn forward(
+ &self,
+ image_embeddings: &Tensor,
+ image_pe: &Tensor,
+ sparse_prompt_embeddings: &Tensor,
+ dense_prompt_embeddings: &Tensor,
+ multimask_output: bool,
+ ) -> Result<(Tensor, Tensor)> {
+ let (masks, iou_pred) = self.predict_masks(
+ image_embeddings,
+ image_pe,
+ sparse_prompt_embeddings,
+ dense_prompt_embeddings,
+ )?;
+ let masks = if multimask_output {
+ masks.i((.., 1..))?
+ } else {
+ masks.i((.., 0..1))?
+ };
+ let iou_pred = if multimask_output {
+ iou_pred.i((.., 1..))?
+ } else {
+ iou_pred.i((.., 0..1))?
+ };
+ Ok((masks, iou_pred))
+ }
+
+ fn predict_masks(
+ &self,
+ image_embeddings: &Tensor,
+ image_pe: &Tensor,
+ sparse_prompt_embeddings: &Tensor,
+ dense_prompt_embeddings: &Tensor,
+ ) -> Result<(Tensor, Tensor)> {
+ // Concatenate ouput tokens.
+ let output_tokens = Tensor::cat(
+ &[self.iou_token.embeddings(), self.mask_tokens.embeddings()],
+ 0,
+ )?;
+ let (d1, d2) = output_tokens.dims2()?;
+ let output_tokens =
+ output_tokens
+ .unsqueeze(0)?
+ .expand((sparse_prompt_embeddings.dim(0)?, d1, d2))?;
+ let tokens = Tensor::cat(&[&output_tokens, sparse_prompt_embeddings], 1)?;
+
+ // Expand per-image data in batch direction to be per mask
+ let src = repeat_interleave(image_embeddings, tokens.dim(0)?, 0)?;
+ let src = (src + dense_prompt_embeddings)?;
+ let pos_src = repeat_interleave(image_pe, tokens.dim(0)?, 0)?;
+ let (b, c, h, w) = src.dims4()?;
+
+ // Run the transformer
+ let (hs, src) = run_transformer(&src, &pos_src, &tokens)?;
+ let iou_token_out = hs.i((.., 0))?;
+ let mask_tokens_out = hs.i((.., 1, 1 + self.num_mask_tokens))?;
+
+ // Upscale mask embeddings and predict masks using the masks tokens.
+ let src = src.transpose(1, 2)?.reshape((b, c, h, w))?;
+ let upscaled_embedding = self
+ .output_upscaling_conv1
+ .forward(&src)?
+ .apply(&self.output_upscaling_ln)?
+ .gelu()?
+ .apply(&self.output_upscaling_conv2)?
+ .gelu()?;
+ let mut hyper_in_list = Vec::with_capacity(self.num_mask_tokens);
+ for (i, mlp) in self.output_hypernetworks_mlps.iter().enumerate() {
+ let h = mlp.forward(&mask_tokens_out.i((.., i))?)?;
+ hyper_in_list.push(h)
+ }
+ let hyper_in = Tensor::stack(hyper_in_list.as_slice(), 1)?;
+ let (b, c, h, w) = upscaled_embedding.dims4()?;
+ let masks = hyper_in
+ .matmul(&upscaled_embedding.reshape((b, c, h * w))?)?
+ .reshape((b, 0, h, w))?;
+
+ // Generate mask quality predictions.
+ let iou_pred = self.iou_prediction_head.forward(&iou_token_out)?;
+ Ok((masks, iou_pred))
+ }
+}
+
+// Equivalent to torch.repeat_interleave
+fn repeat_interleave(_img: &Tensor, _repeats: usize, _dim: usize) -> Result<Tensor> {
+ todo!()
+}
+
// Runs the two-way transformer over the (embedding + dense prompt) `_src`,
// its positional encoding `_pos`, and the prompt/output `_tokens`. The call
// site in `predict_masks` destructures the result as `(hs, src)`: per-token
// outputs and the updated image embedding.
// NOTE(review): still a stub — the transformer port is pending, so
// `predict_masks` panics at runtime until this is implemented.
fn run_transformer(_src: &Tensor, _pos: &Tensor, _tokens: &Tensor) -> Result<(Tensor, Tensor)> {
    todo!()
}