summaryrefslogtreecommitdiff
path: root/candle-examples/examples/segment-anything/model_mask_decoder.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-examples/examples/segment-anything/model_mask_decoder.rs')
-rw-r--r--candle-examples/examples/segment-anything/model_mask_decoder.rs22
1 files changed, 12 insertions, 10 deletions
diff --git a/candle-examples/examples/segment-anything/model_mask_decoder.rs b/candle-examples/examples/segment-anything/model_mask_decoder.rs
index cf3879cd..1ef46eeb 100644
--- a/candle-examples/examples/segment-anything/model_mask_decoder.rs
+++ b/candle-examples/examples/segment-anything/model_mask_decoder.rs
@@ -1,5 +1,5 @@
use candle::{DType, IndexOp, Result, Tensor, D};
-use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};
+use candle_nn::{Linear, Module, VarBuilder};
use crate::model_transformer::TwoWayTransformer;
@@ -60,7 +60,7 @@ pub struct MaskDecoder {
mask_tokens: candle_nn::Embedding,
iou_prediction_head: MlpMaskDecoder,
output_upscaling_conv1: candle_nn::ConvTranspose2d,
- output_upscaling_ln: LayerNorm,
+ output_upscaling_ln: crate::LayerNorm2d,
output_upscaling_conv2: candle_nn::ConvTranspose2d,
num_mask_tokens: usize,
output_hypernetworks_mlps: Vec<MlpMaskDecoder>,
@@ -99,7 +99,7 @@ impl MaskDecoder {
vb.pp("output_upscaling.0"),
)?;
let output_upscaling_ln =
- layer_norm(transformer_dim / 4, 1e-6, vb.pp("output_upscaling.1"))?;
+ crate::LayerNorm2d::new(transformer_dim / 4, 1e-6, vb.pp("output_upscaling.1"))?;
let output_upscaling_conv2 = candle_nn::conv_transpose2d(
transformer_dim / 4,
transformer_dim / 8,
@@ -140,7 +140,7 @@ impl MaskDecoder {
})
}
- fn forward(
+ pub fn forward(
&self,
image_embeddings: &Tensor,
image_pe: &Tensor,
@@ -195,7 +195,7 @@ impl MaskDecoder {
// Run the transformer
let (hs, src) = self.transformer.forward(&src, &pos_src, &tokens)?;
let iou_token_out = hs.i((.., 0))?;
- let mask_tokens_out = hs.i((.., 1, 1 + self.num_mask_tokens))?;
+ let mask_tokens_out = hs.i((.., 1..1 + self.num_mask_tokens))?;
// Upscale mask embeddings and predict masks using the masks tokens.
let src = src.transpose(1, 2)?.reshape((b, c, h, w))?;
@@ -213,9 +213,8 @@ impl MaskDecoder {
}
let hyper_in = Tensor::stack(hyper_in_list.as_slice(), 1)?;
let (b, c, h, w) = upscaled_embedding.dims4()?;
- let masks = hyper_in
- .matmul(&upscaled_embedding.reshape((b, c, h * w))?)?
- .reshape((b, 0, h, w))?;
+ let masks = hyper_in.matmul(&upscaled_embedding.reshape((b, c, h * w))?)?;
+ let masks = masks.reshape((b, masks.elem_count() / b / h / w, h, w))?;
// Generate mask quality predictions.
let iou_pred = self.iou_prediction_head.forward(&iou_token_out)?;
@@ -224,6 +223,9 @@ impl MaskDecoder {
}
// Equivalent to torch.repeat_interleave
-fn repeat_interleave(_img: &Tensor, _repeats: usize, _dim: usize) -> Result<Tensor> {
- todo!()
+fn repeat_interleave(img: &Tensor, repeats: usize, dim: usize) -> Result<Tensor> {
+ let img = img.unsqueeze(dim + 1)?;
+ let mut dims = img.dims().to_vec();
+ dims[dim + 1] = repeats;
+ img.broadcast_as(dims)?.flatten(dim, dim + 1)
}