author     Laurent Mazare <laurent.mazare@gmail.com>   2023-09-08 15:31:29 +0100
committer  GitHub <noreply@github.com>                 2023-09-08 15:31:29 +0100
commit     158ff3c609b22ed998dea5283738cc1ed13aa592 (patch)
tree       ab8f89575a1e9322898147132dc44a676b3e991a /candle-examples/examples/segment-anything/model_image_encoder.rs
parent     e5703d2f56ce24652e7ae85dc74484681e4dbcb9 (diff)
Add tracing to segment-anything (#777)
* Tracing support for segment-anything.
* More tracing.
* Handle the empty slice case.
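
The spans added here are no-ops unless the caller installs a tracing subscriber. A minimal sketch of one way to capture them as a Chrome trace, assuming the tracing-chrome and tracing-subscriber crates (this wiring is illustrative and not part of this commit):

use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;

fn main() {
    // Install a subscriber that records span enter/exit events to a
    // trace file viewable in chrome://tracing; the guard flushes the
    // file when dropped.
    let (chrome_layer, _guard) = ChromeLayerBuilder::new().build();
    tracing_subscriber::registry().with(chrome_layer).init();

    // ... construct the model and run the forward pass here; the TRACE
    // spans added by this commit are recorded while the guard is live ...
}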
Diffstat (limited to 'candle-examples/examples/segment-anything/model_image_encoder.rs')
-rw-r--r--  candle-examples/examples/segment-anything/model_image_encoder.rs  39
1 file changed, 33 insertions(+), 6 deletions(-)
diff --git a/candle-examples/examples/segment-anything/model_image_encoder.rs b/candle-examples/examples/segment-anything/model_image_encoder.rs
index f1b76e23..f997170d 100644
--- a/candle-examples/examples/segment-anything/model_image_encoder.rs
+++ b/candle-examples/examples/segment-anything/model_image_encoder.rs
@@ -1,9 +1,10 @@
 use candle::{DType, IndexOp, Result, Tensor};
-use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};
+use candle_nn::{layer_norm, LayerNorm, Module, VarBuilder};
 
 #[derive(Debug)]
 struct PatchEmbed {
     proj: candle_nn::Conv2d,
+    span: tracing::Span,
 }
 
 impl PatchEmbed {
@@ -21,23 +22,28 @@ impl PatchEmbed {
             ..Default::default()
         };
         let proj = candle_nn::conv2d(in_chans, embed_dim, k_size, cfg, vb.pp("proj"))?;
-        Ok(Self { proj })
+        let span = tracing::span!(tracing::Level::TRACE, "patch-embed");
+        Ok(Self { proj, span })
     }
 }
 
 impl Module for PatchEmbed {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         xs.apply(&self.proj)?.permute((0, 2, 3, 1))
     }
 }
 
 #[derive(Debug)]
 struct Attention {
-    qkv: Linear,
-    proj: Linear,
+    qkv: crate::Linear,
+    proj: crate::Linear,
     num_heads: usize,
     scale: f64,
     rel_pos_hw: Option<(Tensor, Tensor)>,
+    span: tracing::Span,
+    span_rel_pos: tracing::Span,
+    span_softmax: tracing::Span,
 }
 
 impl Attention {
@@ -49,6 +55,9 @@ impl Attention {
         input_size: (usize, usize),
         vb: VarBuilder,
     ) -> Result<Self> {
+        let span = tracing::span!(tracing::Level::TRACE, "attention");
+        let span_rel_pos = tracing::span!(tracing::Level::TRACE, "attn-rel-pos");
+        let span_softmax = tracing::span!(tracing::Level::TRACE, "attn-sm");
         let qkv = crate::linear(vb.pp("qkv"), dim, dim * 3, qkv_bias)?;
         let proj = crate::linear(vb.pp("proj"), dim, dim, true)?;
         let head_dim = dim / num_heads;
@@ -66,6 +75,9 @@ impl Attention {
             num_heads,
             scale,
             rel_pos_hw,
+            span,
+            span_rel_pos,
+            span_softmax,
         })
     }
 
@@ -126,6 +138,7 @@ fn get_rel_pos(q_size: usize, k_size: usize, rel_pos: &Tensor) -> Result<Tensor>
 
 impl Module for Attention {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let (b, h, w, c) = xs.dims4()?;
         let qkv = self
             .qkv
@@ -137,8 +150,14 @@
         let k = qkv.i(1)?;
         let v = qkv.i(2)?;
         let attn = (&q * self.scale)?.matmul(&k.t()?)?;
-        let attn = self.add_decomposed_rel_pos(attn, &q, (h, w), (h, w))?;
-        let attn = candle_nn::ops::softmax_last_dim(&attn)?;
+        let attn = {
+            let _enter = self.span_rel_pos.enter();
+            self.add_decomposed_rel_pos(attn, &q, (h, w), (h, w))?
+        };
+        let attn = {
+            let _enter = self.span_softmax.enter();
+            candle_nn::ops::softmax_last_dim(&attn)?
+        };
         let attn = attn.matmul(&v)?;
         let attn = attn
             .reshape((b, self.num_heads, h, w, c / self.num_heads))?
@@ -155,6 +174,7 @@ struct Block {
     norm2: LayerNorm,
     mlp: crate::MlpBlock,
     window_size: usize,
+    span: tracing::Span,
 }
 
 impl Block {
@@ -183,12 +203,14 @@ impl Block {
             vb.pp("attn"),
         )?;
         let mlp = crate::MlpBlock::new(dim, dim * 4, candle_nn::Activation::Gelu, vb.pp("mlp"))?;
+        let span = tracing::span!(tracing::Level::TRACE, "ie-block");
         Ok(Self {
             norm1,
             attn,
             norm2,
             mlp,
             window_size,
+            span,
         })
     }
 }
@@ -249,6 +271,7 @@ fn window_unpartition(
 
 impl Module for Block {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let shortcut = xs;
         let xs = self.norm1.forward(xs)?;
         let hw = (xs.dim(1)?, xs.dim(2)?);
@@ -277,6 +300,7 @@ pub struct ImageEncoderViT {
     neck_conv2: candle_nn::Conv2d,
     neck_ln2: crate::LayerNorm2d,
     pos_embed: Option<Tensor>,
+    span: tracing::Span,
 }
 
 impl ImageEncoderViT {
@@ -346,6 +370,7 @@ impl ImageEncoderViT {
         } else {
             None
         };
+        let span = tracing::span!(tracing::Level::TRACE, "image-encoder-vit");
         Ok(Self {
             patch_embed,
             blocks,
@@ -354,12 +379,14 @@
             neck_conv2,
             neck_ln2,
             pos_embed,
+            span,
         })
     }
 }
 
 impl Module for ImageEncoderViT {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let xs = self.patch_embed.forward(xs)?;
         let mut xs = match &self.pos_embed {
             Some(pos_embed) => (xs + pos_embed)?,
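
The commit applies one pattern throughout: each module creates its span once at construction time and enters it at the top of forward, so the per-call overhead is a single enter/exit event rather than repeated span creation. A self-contained sketch of that pattern on a hypothetical TimedLayer (illustrative only, not a candle type):

use tracing::{span, Level};

struct TimedLayer {
    span: tracing::Span,
}

impl TimedLayer {
    fn new() -> Self {
        // The span is built once and stored alongside the module's state.
        let span = span!(Level::TRACE, "timed-layer");
        Self { span }
    }

    fn forward(&self, x: f64) -> f64 {
        // Entering the stored span records timing for this call; the
        // guard exits the span when it goes out of scope.
        let _enter = self.span.enter();
        x * 2.0
    }
}

The blocks wrapped around add_decomposed_rel_pos and softmax_last_dim in the diff use the same guard-in-scope trick to time just those two sub-steps of attention separately.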