summaryrefslogtreecommitdiff
path: root/candle-examples/examples/segment-anything
diff options
context:
space:
mode:
Diffstat (limited to 'candle-examples/examples/segment-anything')
-rw-r--r--candle-examples/examples/segment-anything/main.rs47
-rw-r--r--candle-examples/examples/segment-anything/model_image_encoder.rs39
-rw-r--r--candle-examples/examples/segment-anything/model_mask_decoder.rs12
-rw-r--r--candle-examples/examples/segment-anything/model_prompt_encoder.rs4
4 files changed, 87 insertions, 15 deletions
diff --git a/candle-examples/examples/segment-anything/main.rs b/candle-examples/examples/segment-anything/main.rs
index c5095c0e..a749ba2a 100644
--- a/candle-examples/examples/segment-anything/main.rs
+++ b/candle-examples/examples/segment-anything/main.rs
@@ -14,15 +14,17 @@ pub mod model_sam;
pub mod model_transformer;
use candle::{DType, Result, Tensor};
-use candle_nn::{Linear, Module, VarBuilder};
+use candle_nn::{Module, VarBuilder};
use clap::Parser;
pub fn linear(vb: VarBuilder, in_dim: usize, out_dim: usize, bias: bool) -> Result<Linear> {
- if bias {
- candle_nn::linear(in_dim, out_dim, vb)
+ let inner = if bias {
+ candle_nn::linear(in_dim, out_dim, vb)?
} else {
- candle_nn::linear_no_bias(in_dim, out_dim, vb)
- }
+ candle_nn::linear_no_bias(in_dim, out_dim, vb)?
+ };
+ let span = tracing::span!(tracing::Level::TRACE, "linear");
+ Ok(Linear { inner, span })
}
#[derive(Debug)]
@@ -62,6 +64,7 @@ pub struct MlpBlock {
lin1: Linear,
lin2: Linear,
activation: candle_nn::Activation,
+ span: tracing::Span,
}
impl MlpBlock {
@@ -71,24 +74,40 @@ impl MlpBlock {
activation: candle_nn::Activation,
vb: VarBuilder,
) -> Result<Self> {
- let lin1 = candle_nn::linear(embedding_dim, mlp_dim, vb.pp("lin1"))?;
- let lin2 = candle_nn::linear(mlp_dim, embedding_dim, vb.pp("lin2"))?;
+ let lin1 = linear(vb.pp("lin1"), embedding_dim, mlp_dim, true)?;
+ let lin2 = linear(vb.pp("lin2"), mlp_dim, embedding_dim, true)?;
+ let span = tracing::span!(tracing::Level::TRACE, "mlp-block");
Ok(Self {
lin1,
lin2,
activation,
+ span,
})
}
}
impl Module for MlpBlock {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
xs.apply(&self.lin1)?
.apply(&self.activation)?
.apply(&self.lin2)
}
}
+#[derive(Debug)]
+pub struct Linear {
+ inner: candle_nn::Linear,
+ span: tracing::Span,
+}
+
+impl Module for Linear {
+ fn forward(&self, x: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ self.inner.forward(x)
+ }
+}
+
#[derive(Parser)]
struct Args {
#[arg(long)]
@@ -109,10 +128,24 @@ struct Args {
#[arg(long, default_value_t = 0.5)]
point_y: f64,
+
+ /// Enable tracing (generates a trace-timestamp.json file).
+ #[arg(long)]
+ tracing: bool,
}
pub fn main() -> anyhow::Result<()> {
+ use tracing_chrome::ChromeLayerBuilder;
+ use tracing_subscriber::prelude::*;
+
let args = Args::parse();
+ let _guard = if args.tracing {
+ let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
+ tracing_subscriber::registry().with(chrome_layer).init();
+ Some(guard)
+ } else {
+ None
+ };
let device = candle_examples::device(args.cpu)?;
diff --git a/candle-examples/examples/segment-anything/model_image_encoder.rs b/candle-examples/examples/segment-anything/model_image_encoder.rs
index f1b76e23..f997170d 100644
--- a/candle-examples/examples/segment-anything/model_image_encoder.rs
+++ b/candle-examples/examples/segment-anything/model_image_encoder.rs
@@ -1,9 +1,10 @@
use candle::{DType, IndexOp, Result, Tensor};
-use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};
+use candle_nn::{layer_norm, LayerNorm, Module, VarBuilder};
#[derive(Debug)]
struct PatchEmbed {
proj: candle_nn::Conv2d,
+ span: tracing::Span,
}
impl PatchEmbed {
@@ -21,23 +22,28 @@ impl PatchEmbed {
..Default::default()
};
let proj = candle_nn::conv2d(in_chans, embed_dim, k_size, cfg, vb.pp("proj"))?;
- Ok(Self { proj })
+ let span = tracing::span!(tracing::Level::TRACE, "patch-embed");
+ Ok(Self { proj, span })
}
}
impl Module for PatchEmbed {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
xs.apply(&self.proj)?.permute((0, 2, 3, 1))
}
}
#[derive(Debug)]
struct Attention {
- qkv: Linear,
- proj: Linear,
+ qkv: crate::Linear,
+ proj: crate::Linear,
num_heads: usize,
scale: f64,
rel_pos_hw: Option<(Tensor, Tensor)>,
+ span: tracing::Span,
+ span_rel_pos: tracing::Span,
+ span_softmax: tracing::Span,
}
impl Attention {
@@ -49,6 +55,9 @@ impl Attention {
input_size: (usize, usize),
vb: VarBuilder,
) -> Result<Self> {
+ let span = tracing::span!(tracing::Level::TRACE, "attention");
+ let span_rel_pos = tracing::span!(tracing::Level::TRACE, "attn-rel-pos");
+ let span_softmax = tracing::span!(tracing::Level::TRACE, "attn-sm");
let qkv = crate::linear(vb.pp("qkv"), dim, dim * 3, qkv_bias)?;
let proj = crate::linear(vb.pp("proj"), dim, dim, true)?;
let head_dim = dim / num_heads;
@@ -66,6 +75,9 @@ impl Attention {
num_heads,
scale,
rel_pos_hw,
+ span,
+ span_rel_pos,
+ span_softmax,
})
}
@@ -126,6 +138,7 @@ fn get_rel_pos(q_size: usize, k_size: usize, rel_pos: &Tensor) -> Result<Tensor>
impl Module for Attention {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
let (b, h, w, c) = xs.dims4()?;
let qkv = self
.qkv
@@ -137,8 +150,14 @@ impl Module for Attention {
let k = qkv.i(1)?;
let v = qkv.i(2)?;
let attn = (&q * self.scale)?.matmul(&k.t()?)?;
- let attn = self.add_decomposed_rel_pos(attn, &q, (h, w), (h, w))?;
- let attn = candle_nn::ops::softmax_last_dim(&attn)?;
+ let attn = {
+ let _enter = self.span_rel_pos.enter();
+ self.add_decomposed_rel_pos(attn, &q, (h, w), (h, w))?
+ };
+ let attn = {
+ let _enter = self.span_softmax.enter();
+ candle_nn::ops::softmax_last_dim(&attn)?
+ };
let attn = attn.matmul(&v)?;
let attn = attn
.reshape((b, self.num_heads, h, w, c / self.num_heads))?
@@ -155,6 +174,7 @@ struct Block {
norm2: LayerNorm,
mlp: crate::MlpBlock,
window_size: usize,
+ span: tracing::Span,
}
impl Block {
@@ -183,12 +203,14 @@ impl Block {
vb.pp("attn"),
)?;
let mlp = crate::MlpBlock::new(dim, dim * 4, candle_nn::Activation::Gelu, vb.pp("mlp"))?;
+ let span = tracing::span!(tracing::Level::TRACE, "ie-block");
Ok(Self {
norm1,
attn,
norm2,
mlp,
window_size,
+ span,
})
}
}
@@ -249,6 +271,7 @@ fn window_unpartition(
impl Module for Block {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
let shortcut = xs;
let xs = self.norm1.forward(xs)?;
let hw = (xs.dim(1)?, xs.dim(2)?);
@@ -277,6 +300,7 @@ pub struct ImageEncoderViT {
neck_conv2: candle_nn::Conv2d,
neck_ln2: crate::LayerNorm2d,
pos_embed: Option<Tensor>,
+ span: tracing::Span,
}
impl ImageEncoderViT {
@@ -346,6 +370,7 @@ impl ImageEncoderViT {
} else {
None
};
+ let span = tracing::span!(tracing::Level::TRACE, "image-encoder-vit");
Ok(Self {
patch_embed,
blocks,
@@ -354,12 +379,14 @@ impl ImageEncoderViT {
neck_conv2,
neck_ln2,
pos_embed,
+ span,
})
}
}
impl Module for ImageEncoderViT {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
let xs = self.patch_embed.forward(xs)?;
let mut xs = match &self.pos_embed {
Some(pos_embed) => (xs + pos_embed)?,
diff --git a/candle-examples/examples/segment-anything/model_mask_decoder.rs b/candle-examples/examples/segment-anything/model_mask_decoder.rs
index 598af1f6..1f6d62a4 100644
--- a/candle-examples/examples/segment-anything/model_mask_decoder.rs
+++ b/candle-examples/examples/segment-anything/model_mask_decoder.rs
@@ -1,12 +1,13 @@
use candle::{IndexOp, Result, Tensor};
-use candle_nn::{Linear, Module, VarBuilder};
+use candle_nn::{Module, VarBuilder};
use crate::model_transformer::TwoWayTransformer;
#[derive(Debug)]
struct MlpMaskDecoder {
- layers: Vec<Linear>,
+ layers: Vec<crate::Linear>,
sigmoid_output: bool,
+ span: tracing::Span,
}
impl MlpMaskDecoder {
@@ -30,15 +31,18 @@ impl MlpMaskDecoder {
let layer = crate::linear(vb.pp(i), in_dim, out_dim, true)?;
layers.push(layer)
}
+ let span = tracing::span!(tracing::Level::TRACE, "mlp-mask-decoder");
Ok(Self {
layers,
sigmoid_output,
+ span,
})
}
}
impl Module for MlpMaskDecoder {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
let mut xs = xs.clone();
for (i, layer) in self.layers.iter().enumerate() {
xs = layer.forward(&xs)?;
@@ -65,6 +69,7 @@ pub struct MaskDecoder {
num_mask_tokens: usize,
output_hypernetworks_mlps: Vec<MlpMaskDecoder>,
transformer: TwoWayTransformer,
+ span: tracing::Span,
}
impl MaskDecoder {
@@ -127,6 +132,7 @@ impl MaskDecoder {
/* mlp_dim */ 2048,
vb.pp("transformer"),
)?;
+ let span = tracing::span!(tracing::Level::TRACE, "mask-decoder");
Ok(Self {
iou_token,
mask_tokens,
@@ -137,6 +143,7 @@ impl MaskDecoder {
num_mask_tokens,
output_hypernetworks_mlps,
transformer,
+ span,
})
}
@@ -148,6 +155,7 @@ impl MaskDecoder {
dense_prompt_embeddings: &Tensor,
multimask_output: bool,
) -> Result<(Tensor, Tensor)> {
+ let _enter = self.span.enter();
let (masks, iou_pred) = self.predict_masks(
image_embeddings,
image_pe,
diff --git a/candle-examples/examples/segment-anything/model_prompt_encoder.rs b/candle-examples/examples/segment-anything/model_prompt_encoder.rs
index b401a900..40cc6e36 100644
--- a/candle-examples/examples/segment-anything/model_prompt_encoder.rs
+++ b/candle-examples/examples/segment-anything/model_prompt_encoder.rs
@@ -64,6 +64,7 @@ pub struct PromptEncoder {
image_embedding_size: (usize, usize),
input_image_size: (usize, usize),
embed_dim: usize,
+ span: tracing::Span,
}
impl PromptEncoder {
@@ -108,6 +109,7 @@ impl PromptEncoder {
let emb = candle_nn::embedding(1, embed_dim, vb_e.pp(i))?;
point_embeddings.push(emb)
}
+ let span = tracing::span!(tracing::Level::TRACE, "prompt-encoder");
Ok(Self {
pe_layer,
point_embeddings,
@@ -121,6 +123,7 @@ impl PromptEncoder {
image_embedding_size,
input_image_size,
embed_dim,
+ span,
})
}
@@ -201,6 +204,7 @@ impl PromptEncoder {
boxes: Option<&Tensor>,
masks: Option<&Tensor>,
) -> Result<(Tensor, Tensor)> {
+ let _enter = self.span.enter();
let se_points = match points {
Some((coords, labels)) => Some(self.embed_points(coords, labels, boxes.is_none())?),
None => None,