diff options
Diffstat (limited to 'candle-examples/examples/segment-anything/model_mask_decoder.rs')
-rw-r--r-- | candle-examples/examples/segment-anything/model_mask_decoder.rs | 12 |
1 files changed, 10 insertions, 2 deletions
diff --git a/candle-examples/examples/segment-anything/model_mask_decoder.rs b/candle-examples/examples/segment-anything/model_mask_decoder.rs index 598af1f6..1f6d62a4 100644 --- a/candle-examples/examples/segment-anything/model_mask_decoder.rs +++ b/candle-examples/examples/segment-anything/model_mask_decoder.rs @@ -1,12 +1,13 @@ use candle::{IndexOp, Result, Tensor}; -use candle_nn::{Linear, Module, VarBuilder}; +use candle_nn::{Module, VarBuilder}; use crate::model_transformer::TwoWayTransformer; #[derive(Debug)] struct MlpMaskDecoder { - layers: Vec<Linear>, + layers: Vec<crate::Linear>, sigmoid_output: bool, + span: tracing::Span, } impl MlpMaskDecoder { @@ -30,15 +31,18 @@ impl MlpMaskDecoder { let layer = crate::linear(vb.pp(i), in_dim, out_dim, true)?; layers.push(layer) } + let span = tracing::span!(tracing::Level::TRACE, "mlp-mask-decoder"); Ok(Self { layers, sigmoid_output, + span, }) } } impl Module for MlpMaskDecoder { fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); let mut xs = xs.clone(); for (i, layer) in self.layers.iter().enumerate() { xs = layer.forward(&xs)?; @@ -65,6 +69,7 @@ pub struct MaskDecoder { num_mask_tokens: usize, output_hypernetworks_mlps: Vec<MlpMaskDecoder>, transformer: TwoWayTransformer, + span: tracing::Span, } impl MaskDecoder { @@ -127,6 +132,7 @@ impl MaskDecoder { /* mlp_dim */ 2048, vb.pp("transformer"), )?; + let span = tracing::span!(tracing::Level::TRACE, "mask-decoder"); Ok(Self { iou_token, mask_tokens, @@ -137,6 +143,7 @@ impl MaskDecoder { num_mask_tokens, output_hypernetworks_mlps, transformer, + span, }) } @@ -148,6 +155,7 @@ impl MaskDecoder { dense_prompt_embeddings: &Tensor, multimask_output: bool, ) -> Result<(Tensor, Tensor)> { + let _enter = self.span.enter(); let (masks, iou_pred) = self.predict_masks( image_embeddings, image_pe, |