author    Laurent Mazare <laurent.mazare@gmail.com> 2023-09-08 15:31:29 +0100
committer GitHub <noreply@github.com> 2023-09-08 15:31:29 +0100
commit    158ff3c609b22ed998dea5283738cc1ed13aa592 (patch)
tree      ab8f89575a1e9322898147132dc44a676b3e991a /candle-examples/examples/segment-anything/model_image_encoder.rs
parent    e5703d2f56ce24652e7ae85dc74484681e4dbcb9 (diff)
Add tracing to segment-anything (#777)
* Tracing support for segment-anything.
* More tracing.
* Handle the empty slice case.
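The spans introduced in this diff follow the pattern used across the candle examples: each module builds a tracing::Span once at construction time and enters it at the top of forward, so the per-call cost on the hot path is a single enter/exit. As a minimal sketch, this is how such TRACE-level spans can be recorded with the tracing-chrome subscriber that other candle examples wire up (an assumption here; the exact setup in the segment-anything main.rs may differ):

use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;

fn main() {
    // Emits a trace-*.json file viewable in chrome://tracing or Perfetto.
    // Keep the guard alive until exit so the trace file is flushed.
    let (chrome_layer, _guard) = ChromeLayerBuilder::new().build();
    tracing_subscriber::registry().with(chrome_layer).init();

    // The pattern from this diff: build the span once at construction...
    let span = tracing::span!(tracing::Level::TRACE, "patch-embed");
    // ...then enter it on each forward pass; the span covers the scope
    // until `_enter` is dropped.
    let _enter = span.enter();
}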
Diffstat (limited to 'candle-examples/examples/segment-anything/model_image_encoder.rs')
-rw-r--r-- candle-examples/examples/segment-anything/model_image_encoder.rs | 39
1 file changed, 33 insertions(+), 6 deletions(-)
diff --git a/candle-examples/examples/segment-anything/model_image_encoder.rs b/candle-examples/examples/segment-anything/model_image_encoder.rs
index f1b76e23..f997170d 100644
--- a/candle-examples/examples/segment-anything/model_image_encoder.rs
+++ b/candle-examples/examples/segment-anything/model_image_encoder.rs
@@ -1,9 +1,10 @@
use candle::{DType, IndexOp, Result, Tensor};
-use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};
+use candle_nn::{layer_norm, LayerNorm, Module, VarBuilder};
#[derive(Debug)]
struct PatchEmbed {
proj: candle_nn::Conv2d,
+ span: tracing::Span,
}
impl PatchEmbed {
@@ -21,23 +22,28 @@ impl PatchEmbed {
..Default::default()
};
let proj = candle_nn::conv2d(in_chans, embed_dim, k_size, cfg, vb.pp("proj"))?;
- Ok(Self { proj })
+ let span = tracing::span!(tracing::Level::TRACE, "patch-embed");
+ Ok(Self { proj, span })
}
}
impl Module for PatchEmbed {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
xs.apply(&self.proj)?.permute((0, 2, 3, 1))
}
}
#[derive(Debug)]
struct Attention {
- qkv: Linear,
- proj: Linear,
+ qkv: crate::Linear,
+ proj: crate::Linear,
num_heads: usize,
scale: f64,
rel_pos_hw: Option<(Tensor, Tensor)>,
+ span: tracing::Span,
+ span_rel_pos: tracing::Span,
+ span_softmax: tracing::Span,
}
impl Attention {
@@ -49,6 +55,9 @@ impl Attention {
input_size: (usize, usize),
vb: VarBuilder,
) -> Result<Self> {
+ let span = tracing::span!(tracing::Level::TRACE, "attention");
+ let span_rel_pos = tracing::span!(tracing::Level::TRACE, "attn-rel-pos");
+ let span_softmax = tracing::span!(tracing::Level::TRACE, "attn-sm");
let qkv = crate::linear(vb.pp("qkv"), dim, dim * 3, qkv_bias)?;
let proj = crate::linear(vb.pp("proj"), dim, dim, true)?;
let head_dim = dim / num_heads;
@@ -66,6 +75,9 @@ impl Attention {
num_heads,
scale,
rel_pos_hw,
+ span,
+ span_rel_pos,
+ span_softmax,
})
}
@@ -126,6 +138,7 @@ fn get_rel_pos(q_size: usize, k_size: usize, rel_pos: &Tensor) -> Result<Tensor>
impl Module for Attention {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
let (b, h, w, c) = xs.dims4()?;
let qkv = self
.qkv
@@ -137,8 +150,14 @@ impl Module for Attention {
let k = qkv.i(1)?;
let v = qkv.i(2)?;
let attn = (&q * self.scale)?.matmul(&k.t()?)?;
- let attn = self.add_decomposed_rel_pos(attn, &q, (h, w), (h, w))?;
- let attn = candle_nn::ops::softmax_last_dim(&attn)?;
+ let attn = {
+ let _enter = self.span_rel_pos.enter();
+ self.add_decomposed_rel_pos(attn, &q, (h, w), (h, w))?
+ };
+ let attn = {
+ let _enter = self.span_softmax.enter();
+ candle_nn::ops::softmax_last_dim(&attn)?
+ };
let attn = attn.matmul(&v)?;
let attn = attn
.reshape((b, self.num_heads, h, w, c / self.num_heads))?
@@ -155,6 +174,7 @@ struct Block {
norm2: LayerNorm,
mlp: crate::MlpBlock,
window_size: usize,
+ span: tracing::Span,
}
impl Block {
@@ -183,12 +203,14 @@ impl Block {
vb.pp("attn"),
)?;
let mlp = crate::MlpBlock::new(dim, dim * 4, candle_nn::Activation::Gelu, vb.pp("mlp"))?;
+ let span = tracing::span!(tracing::Level::TRACE, "ie-block");
Ok(Self {
norm1,
attn,
norm2,
mlp,
window_size,
+ span,
})
}
}
@@ -249,6 +271,7 @@ fn window_unpartition(
impl Module for Block {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
let shortcut = xs;
let xs = self.norm1.forward(xs)?;
let hw = (xs.dim(1)?, xs.dim(2)?);
@@ -277,6 +300,7 @@ pub struct ImageEncoderViT {
neck_conv2: candle_nn::Conv2d,
neck_ln2: crate::LayerNorm2d,
pos_embed: Option<Tensor>,
+ span: tracing::Span,
}
impl ImageEncoderViT {
@@ -346,6 +370,7 @@ impl ImageEncoderViT {
} else {
None
};
+ let span = tracing::span!(tracing::Level::TRACE, "image-encoder-vit");
Ok(Self {
patch_embed,
blocks,
@@ -354,12 +379,14 @@ impl ImageEncoderViT {
neck_conv2,
neck_ln2,
pos_embed,
+ span,
})
}
}
impl Module for ImageEncoderViT {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
let xs = self.patch_embed.forward(xs)?;
let mut xs = match &self.pos_embed {
Some(pos_embed) => (xs + pos_embed)?,