summaryrefslogtreecommitdiff
path: root/candle-examples/examples/segment-anything/model_prompt_encoder.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-examples/examples/segment-anything/model_prompt_encoder.rs')
-rw-r--r--candle-examples/examples/segment-anything/model_prompt_encoder.rs4
1 files changed, 4 insertions, 0 deletions
diff --git a/candle-examples/examples/segment-anything/model_prompt_encoder.rs b/candle-examples/examples/segment-anything/model_prompt_encoder.rs
index b401a900..40cc6e36 100644
--- a/candle-examples/examples/segment-anything/model_prompt_encoder.rs
+++ b/candle-examples/examples/segment-anything/model_prompt_encoder.rs
@@ -64,6 +64,7 @@ pub struct PromptEncoder {
image_embedding_size: (usize, usize),
input_image_size: (usize, usize),
embed_dim: usize,
+ span: tracing::Span,
}
impl PromptEncoder {
@@ -108,6 +109,7 @@ impl PromptEncoder {
let emb = candle_nn::embedding(1, embed_dim, vb_e.pp(i))?;
point_embeddings.push(emb)
}
+ let span = tracing::span!(tracing::Level::TRACE, "prompt-encoder");
Ok(Self {
pe_layer,
point_embeddings,
@@ -121,6 +123,7 @@ impl PromptEncoder {
image_embedding_size,
input_image_size,
embed_dim,
+ span,
})
}
@@ -201,6 +204,7 @@ impl PromptEncoder {
boxes: Option<&Tensor>,
masks: Option<&Tensor>,
) -> Result<(Tensor, Tensor)> {
+ let _enter = self.span.enter();
let se_points = match points {
Some((coords, labels)) => Some(self.embed_points(coords, labels, boxes.is_none())?),
None => None,