summaryrefslogtreecommitdiff
path: root/candle-examples/examples/segment-anything/model_mask_decoder.rs
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-09-08 15:31:29 +0100
committerGitHub <noreply@github.com>2023-09-08 15:31:29 +0100
commit158ff3c609b22ed998dea5283738cc1ed13aa592 (patch)
treeab8f89575a1e9322898147132dc44a676b3e991a /candle-examples/examples/segment-anything/model_mask_decoder.rs
parente5703d2f56ce24652e7ae85dc74484681e4dbcb9 (diff)
downloadcandle-158ff3c609b22ed998dea5283738cc1ed13aa592.tar.gz
candle-158ff3c609b22ed998dea5283738cc1ed13aa592.tar.bz2
candle-158ff3c609b22ed998dea5283738cc1ed13aa592.zip
Add tracing to segment-anything (#777)
* Tracing support for segment-anything. * More tracing. * Handle the empty slice case.
Diffstat (limited to 'candle-examples/examples/segment-anything/model_mask_decoder.rs')
-rw-r--r--candle-examples/examples/segment-anything/model_mask_decoder.rs12
1 files changed, 10 insertions, 2 deletions
diff --git a/candle-examples/examples/segment-anything/model_mask_decoder.rs b/candle-examples/examples/segment-anything/model_mask_decoder.rs
index 598af1f6..1f6d62a4 100644
--- a/candle-examples/examples/segment-anything/model_mask_decoder.rs
+++ b/candle-examples/examples/segment-anything/model_mask_decoder.rs
@@ -1,12 +1,13 @@
use candle::{IndexOp, Result, Tensor};
-use candle_nn::{Linear, Module, VarBuilder};
+use candle_nn::{Module, VarBuilder};
use crate::model_transformer::TwoWayTransformer;
#[derive(Debug)]
struct MlpMaskDecoder {
- layers: Vec<Linear>,
+ layers: Vec<crate::Linear>,
sigmoid_output: bool,
+ span: tracing::Span,
}
impl MlpMaskDecoder {
@@ -30,15 +31,18 @@ impl MlpMaskDecoder {
let layer = crate::linear(vb.pp(i), in_dim, out_dim, true)?;
layers.push(layer)
}
+ let span = tracing::span!(tracing::Level::TRACE, "mlp-mask-decoder");
Ok(Self {
layers,
sigmoid_output,
+ span,
})
}
}
impl Module for MlpMaskDecoder {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
let mut xs = xs.clone();
for (i, layer) in self.layers.iter().enumerate() {
xs = layer.forward(&xs)?;
@@ -65,6 +69,7 @@ pub struct MaskDecoder {
num_mask_tokens: usize,
output_hypernetworks_mlps: Vec<MlpMaskDecoder>,
transformer: TwoWayTransformer,
+ span: tracing::Span,
}
impl MaskDecoder {
@@ -127,6 +132,7 @@ impl MaskDecoder {
/* mlp_dim */ 2048,
vb.pp("transformer"),
)?;
+ let span = tracing::span!(tracing::Level::TRACE, "mask-decoder");
Ok(Self {
iou_token,
mask_tokens,
@@ -137,6 +143,7 @@ impl MaskDecoder {
num_mask_tokens,
output_hypernetworks_mlps,
transformer,
+ span,
})
}
@@ -148,6 +155,7 @@ impl MaskDecoder {
dense_prompt_embeddings: &Tensor,
multimask_output: bool,
) -> Result<(Tensor, Tensor)> {
+ let _enter = self.span.enter();
let (masks, iou_pred) = self.predict_masks(
image_embeddings,
image_pe,