diff options
Diffstat (limited to 'candle-examples/examples/segment-anything/model_mask_decoder.rs')
-rw-r--r-- | candle-examples/examples/segment-anything/model_mask_decoder.rs | 22 |
1 files changed, 12 insertions, 10 deletions
diff --git a/candle-examples/examples/segment-anything/model_mask_decoder.rs b/candle-examples/examples/segment-anything/model_mask_decoder.rs index cf3879cd..1ef46eeb 100644 --- a/candle-examples/examples/segment-anything/model_mask_decoder.rs +++ b/candle-examples/examples/segment-anything/model_mask_decoder.rs @@ -1,5 +1,5 @@ use candle::{DType, IndexOp, Result, Tensor, D}; -use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder}; +use candle_nn::{Linear, Module, VarBuilder}; use crate::model_transformer::TwoWayTransformer; @@ -60,7 +60,7 @@ pub struct MaskDecoder { mask_tokens: candle_nn::Embedding, iou_prediction_head: MlpMaskDecoder, output_upscaling_conv1: candle_nn::ConvTranspose2d, - output_upscaling_ln: LayerNorm, + output_upscaling_ln: crate::LayerNorm2d, output_upscaling_conv2: candle_nn::ConvTranspose2d, num_mask_tokens: usize, output_hypernetworks_mlps: Vec<MlpMaskDecoder>, @@ -99,7 +99,7 @@ impl MaskDecoder { vb.pp("output_upscaling.0"), )?; let output_upscaling_ln = - layer_norm(transformer_dim / 4, 1e-6, vb.pp("output_upscaling.1"))?; + crate::LayerNorm2d::new(transformer_dim / 4, 1e-6, vb.pp("output_upscaling.1"))?; let output_upscaling_conv2 = candle_nn::conv_transpose2d( transformer_dim / 4, transformer_dim / 8, @@ -140,7 +140,7 @@ impl MaskDecoder { }) } - fn forward( + pub fn forward( &self, image_embeddings: &Tensor, image_pe: &Tensor, @@ -195,7 +195,7 @@ impl MaskDecoder { // Run the transformer let (hs, src) = self.transformer.forward(&src, &pos_src, &tokens)?; let iou_token_out = hs.i((.., 0))?; - let mask_tokens_out = hs.i((.., 1, 1 + self.num_mask_tokens))?; + let mask_tokens_out = hs.i((.., 1..1 + self.num_mask_tokens))?; // Upscale mask embeddings and predict masks using the masks tokens. let src = src.transpose(1, 2)?.reshape((b, c, h, w))?; @@ -213,9 +213,8 @@ impl MaskDecoder { } let hyper_in = Tensor::stack(hyper_in_list.as_slice(), 1)?; let (b, c, h, w) = upscaled_embedding.dims4()?; - let masks = hyper_in - .matmul(&upscaled_embedding.reshape((b, c, h * w))?)? - .reshape((b, 0, h, w))?; + let masks = hyper_in.matmul(&upscaled_embedding.reshape((b, c, h * w))?)?; + let masks = masks.reshape((b, masks.elem_count() / b / h / w, h, w))?; // Generate mask quality predictions. let iou_pred = self.iou_prediction_head.forward(&iou_token_out)?; @@ -224,6 +223,9 @@ impl MaskDecoder { } // Equivalent to torch.repeat_interleave -fn repeat_interleave(_img: &Tensor, _repeats: usize, _dim: usize) -> Result<Tensor> { - todo!() +fn repeat_interleave(img: &Tensor, repeats: usize, dim: usize) -> Result<Tensor> { + let img = img.unsqueeze(dim + 1)?; + let mut dims = img.dims().to_vec(); + dims[dim + 1] = repeats; + img.broadcast_as(dims)?.flatten(dim, dim + 1) } |