Diffstat (limited to 'candle-examples/examples/segment-anything/model_mask_decoder.rs')
-rw-r--r-- | candle-examples/examples/segment-anything/model_mask_decoder.rs | 222
1 file changed, 222 insertions, 0 deletions
diff --git a/candle-examples/examples/segment-anything/model_mask_decoder.rs b/candle-examples/examples/segment-anything/model_mask_decoder.rs
new file mode 100644
index 00000000..55a006c4
--- /dev/null
+++ b/candle-examples/examples/segment-anything/model_mask_decoder.rs
@@ -0,0 +1,222 @@
+use candle::{DType, IndexOp, Result, Tensor, D};
+use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};
+
+#[derive(Debug)]
+struct MlpMaskDecoder {
+    layers: Vec<Linear>,
+    sigmoid_output: bool,
+}
+
+impl MlpMaskDecoder {
+    fn new(
+        input_dim: usize,
+        hidden_dim: usize,
+        output_dim: usize,
+        num_layers: usize,
+        sigmoid_output: bool,
+        vb: VarBuilder,
+    ) -> Result<Self> {
+        let mut layers = Vec::with_capacity(num_layers);
+        let vb = vb.pp("layers");
+        for i in 0..num_layers {
+            let in_dim = if i == 0 { input_dim } else { hidden_dim };
+            let out_dim = if i + 1 == num_layers {
+                output_dim
+            } else {
+                hidden_dim
+            };
+            let layer = crate::linear(vb.pp(i), in_dim, out_dim, true)?;
+            layers.push(layer)
+        }
+        Ok(Self {
+            layers,
+            sigmoid_output,
+        })
+    }
+}
+
+impl Module for MlpMaskDecoder {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let mut xs = xs.clone();
+        for (i, layer) in self.layers.iter().enumerate() {
+            xs = layer.forward(&xs)?;
+            if i + 1 < self.layers.len() {
+                xs = xs.relu()?
+            }
+        }
+        if self.sigmoid_output {
+            candle_nn::ops::sigmoid(&xs)
+        } else {
+            Ok(xs)
+        }
+    }
+}
+
+#[derive(Debug)]
+struct MaskDecoder {
+    iou_token: candle_nn::Embedding,
+    mask_tokens: candle_nn::Embedding,
+    iou_prediction_head: MlpMaskDecoder,
+    output_upscaling_conv1: candle_nn::ConvTranspose2d,
+    output_upscaling_ln: LayerNorm,
+    output_upscaling_conv2: candle_nn::ConvTranspose2d,
+    num_mask_tokens: usize,
+    output_hypernetworks_mlps: Vec<MlpMaskDecoder>,
+}
+
+impl MaskDecoder {
+    fn new(
+        transformer_dim: usize,
+        num_multimask_outputs: usize,
+        iou_head_depth: usize,
+        iou_head_hidden_dim: usize,
+        vb: VarBuilder,
+    ) -> Result<Self> {
+        let num_mask_tokens = num_multimask_outputs + 1;
+        let iou_prediction_head = MlpMaskDecoder::new(
+            transformer_dim,
+            iou_head_hidden_dim,
+            num_mask_tokens,
+            iou_head_depth,
+            false,
+            vb.pp("iou_prediction_head"),
+        )?;
+        let iou_token = candle_nn::embedding(1, transformer_dim, vb.pp("iou_token"))?;
+        let mask_tokens =
+            candle_nn::embedding(num_mask_tokens, transformer_dim, vb.pp("mask_tokens"))?;
+        let cfg = candle_nn::ConvTranspose2dConfig {
+            stride: 2,
+            ..Default::default()
+        };
+        let output_upscaling_conv1 = candle_nn::conv_transpose2d(
+            transformer_dim,
+            transformer_dim / 4,
+            2,
+            cfg,
+            vb.pp("output_upscaling.0"),
+        )?;
+        let output_upscaling_ln =
+            layer_norm(transformer_dim / 4, 1e-6, vb.pp("output_upscaling.1"))?;
+        let output_upscaling_conv2 = candle_nn::conv_transpose2d(
+            transformer_dim / 4,
+            transformer_dim / 8,
+            2,
+            cfg,
+            vb.pp("output_upscaling.3"),
+        )?;
+        let mut output_hypernetworks_mlps = Vec::with_capacity(num_mask_tokens);
+        let vb_o = vb.pp("output_hypernetworks_mlps");
+        for i in 0..num_mask_tokens {
+            let mlp = MlpMaskDecoder::new(
+                transformer_dim,
+                transformer_dim,
+                transformer_dim / 8,
+                3,
+                false,
+                vb_o.pp(i),
+            )?;
+            output_hypernetworks_mlps.push(mlp)
+        }
+        Ok(Self {
+            iou_token,
+            mask_tokens,
+            iou_prediction_head,
+            output_upscaling_conv1,
+            output_upscaling_ln,
+            output_upscaling_conv2,
+            num_mask_tokens,
+            output_hypernetworks_mlps,
+        })
+    }
+
+    fn forward(
+        &self,
+        image_embeddings: &Tensor,
+        image_pe: &Tensor,
+        sparse_prompt_embeddings: &Tensor,
+        dense_prompt_embeddings: &Tensor,
+        multimask_output: bool,
+    ) -> Result<(Tensor, Tensor)> {
+        let (masks, iou_pred) = self.predict_masks(
+            image_embeddings,
+            image_pe,
+            sparse_prompt_embeddings,
+            dense_prompt_embeddings,
+        )?;
+        let masks = if multimask_output {
+            masks.i((.., 1..))?
+        } else {
+            masks.i((.., 0..1))?
+        };
+        let iou_pred = if multimask_output {
+            iou_pred.i((.., 1..))?
+        } else {
+            iou_pred.i((.., 0..1))?
+        };
+        Ok((masks, iou_pred))
+    }
+
+    fn predict_masks(
+        &self,
+        image_embeddings: &Tensor,
+        image_pe: &Tensor,
+        sparse_prompt_embeddings: &Tensor,
+        dense_prompt_embeddings: &Tensor,
+    ) -> Result<(Tensor, Tensor)> {
+        // Concatenate output tokens.
+        let output_tokens = Tensor::cat(
+            &[self.iou_token.embeddings(), self.mask_tokens.embeddings()],
+            0,
+        )?;
+        let (d1, d2) = output_tokens.dims2()?;
+        let output_tokens =
+            output_tokens
+                .unsqueeze(0)?
+                .expand((sparse_prompt_embeddings.dim(0)?, d1, d2))?;
+        let tokens = Tensor::cat(&[&output_tokens, sparse_prompt_embeddings], 1)?;
+
+        // Expand per-image data in the batch direction to be per-mask.
+        let src = repeat_interleave(image_embeddings, tokens.dim(0)?, 0)?;
+        let src = (src + dense_prompt_embeddings)?;
+        let pos_src = repeat_interleave(image_pe, tokens.dim(0)?, 0)?;
+        let (b, c, h, w) = src.dims4()?;
+
+        // Run the transformer.
+        let (hs, src) = run_transformer(&src, &pos_src, &tokens)?;
+        let iou_token_out = hs.i((.., 0))?;
+        let mask_tokens_out = hs.i((.., 1..1 + self.num_mask_tokens))?;
+
+        // Upscale mask embeddings and predict masks using the mask tokens.
+        let src = src.transpose(1, 2)?.reshape((b, c, h, w))?;
+        let upscaled_embedding = self
+            .output_upscaling_conv1
+            .forward(&src)?
+            .apply(&self.output_upscaling_ln)?
+            .gelu()?
+            .apply(&self.output_upscaling_conv2)?
+            .gelu()?;
+        let mut hyper_in_list = Vec::with_capacity(self.num_mask_tokens);
+        for (i, mlp) in self.output_hypernetworks_mlps.iter().enumerate() {
+            let h = mlp.forward(&mask_tokens_out.i((.., i))?)?;
+            hyper_in_list.push(h)
+        }
+        let hyper_in = Tensor::stack(hyper_in_list.as_slice(), 1)?;
+        let (b, c, h, w) = upscaled_embedding.dims4()?;
+        let masks = hyper_in
+            .matmul(&upscaled_embedding.reshape((b, c, h * w))?)?
+            .reshape((b, (), h, w))?;
+
+        // Generate mask quality predictions.
+        let iou_pred = self.iou_prediction_head.forward(&iou_token_out)?;
+        Ok((masks, iou_pred))
+    }
+}
+
+// Equivalent to torch.repeat_interleave
+fn repeat_interleave(_img: &Tensor, _repeats: usize, _dim: usize) -> Result<Tensor> {
+    todo!()
+}
+
+fn run_transformer(_src: &Tensor, _pos: &Tensor, _tokens: &Tensor) -> Result<(Tensor, Tensor)> {
+    todo!()
+}
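
Note: the commit leaves repeat_interleave as a todo!() stub, to be filled in by a later change. As a reference for what it needs to do, here is a minimal sketch of one possible implementation built only from candle tensor ops that exist today (unsqueeze, broadcast_as, flatten), reusing the file's candle::{Result, Tensor} imports; it assumes a uniform scalar repeat count and is an illustration, not the implementation that eventually landed:

fn repeat_interleave(img: &Tensor, repeats: usize, dim: usize) -> Result<Tensor> {
    // Insert a size-1 axis right after `dim`, e.g. (b, c, h, w) -> (b, 1, c, h, w) for dim = 0.
    let img = img.unsqueeze(dim + 1)?;
    // Broadcast that axis up to `repeats` copies of each slice...
    let mut dims = img.dims().to_vec();
    dims[dim + 1] = repeats;
    // ...then merge the two axes so every slice along `dim` appears `repeats`
    // times in a row, matching torch.repeat_interleave with a scalar argument.
    img.broadcast_as(dims)?.flatten(dim, dim + 1)
}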
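
To make the tensor flow in predict_masks easier to follow, here is a hypothetical shape trace for batch size 1, a handful of sparse prompt tokens, and SAM-style defaults (transformer_dim C = 256, a 64x64 embedding grid, num_multimask_outputs = 3, so num_mask_tokens = 4); the concrete numbers are assumptions for illustration:

// tokens          : (1, 5 + n_sparse, 256)  -- 1 IoU token, 4 mask tokens, prompt tokens
// src             : (1, 256, 64, 64)        -- image embedding + dense prompt embedding
// hs              : (1, 5 + n_sparse, 256)  -- transformer output, one vector per token
// iou_token_out   : (1, 256)                -- hs[:, 0]
// mask_tokens_out : (1, 4, 256)             -- hs[:, 1..5]
// upscaled        : (1, 32, 256, 256)       -- two stride-2 transposed convs: 64 -> 128 -> 256
// hyper_in        : (1, 4, 32)              -- one 3-layer hypernetwork MLP per mask token
// masks           : (1, 4, 256, 256)        -- hyper_in @ upscaled.reshape(1, 32, 256 * 256)
// iou_pred        : (1, 4)                  -- IoU head applied to the IoU token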
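
For orientation, a hypothetical construction sketch with SAM's usual hyper-parameters (256-wide transformer, 3 multimask outputs plus one extra token for the single-mask path, a 3-layer IoU head with 256 hidden units); the sam.safetensors path and the mask_decoder prefix are assumptions for illustration, not something this diff wires up:

use candle::{DType, Device, Result};
use candle_nn::VarBuilder;

fn build_mask_decoder() -> Result<MaskDecoder> {
    let dev = Device::Cpu;
    // Hypothetical checkpoint file; the example loads its weights elsewhere.
    let vb = unsafe {
        VarBuilder::from_mmaped_safetensors(&["sam.safetensors"], DType::F32, &dev)?
    };
    // transformer_dim = 256, num_multimask_outputs = 3,
    // iou_head_depth = 3, iou_head_hidden_dim = 256 (SAM defaults).
    MaskDecoder::new(256, 3, 3, 256, vb.pp("mask_decoder"))
}

At forward time, multimask_output = true selects the three candidate masks (masks[:, 1:]) while false keeps only the single-mask output (masks[:, 0:1]), mirroring the reference PyTorch decoder.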