diff options
Diffstat (limited to 'candle-examples/examples/segment-anything/main.rs')
-rw-r--r-- | candle-examples/examples/segment-anything/main.rs | 47 |
1 files changed, 40 insertions, 7 deletions
diff --git a/candle-examples/examples/segment-anything/main.rs b/candle-examples/examples/segment-anything/main.rs index c5095c0e..a749ba2a 100644 --- a/candle-examples/examples/segment-anything/main.rs +++ b/candle-examples/examples/segment-anything/main.rs @@ -14,15 +14,17 @@ pub mod model_sam; pub mod model_transformer; use candle::{DType, Result, Tensor}; -use candle_nn::{Linear, Module, VarBuilder}; +use candle_nn::{Module, VarBuilder}; use clap::Parser; pub fn linear(vb: VarBuilder, in_dim: usize, out_dim: usize, bias: bool) -> Result<Linear> { - if bias { - candle_nn::linear(in_dim, out_dim, vb) + let inner = if bias { + candle_nn::linear(in_dim, out_dim, vb)? } else { - candle_nn::linear_no_bias(in_dim, out_dim, vb) - } + candle_nn::linear_no_bias(in_dim, out_dim, vb)? + }; + let span = tracing::span!(tracing::Level::TRACE, "linear"); + Ok(Linear { inner, span }) } #[derive(Debug)] @@ -62,6 +64,7 @@ pub struct MlpBlock { lin1: Linear, lin2: Linear, activation: candle_nn::Activation, + span: tracing::Span, } impl MlpBlock { @@ -71,24 +74,40 @@ impl MlpBlock { activation: candle_nn::Activation, vb: VarBuilder, ) -> Result<Self> { - let lin1 = candle_nn::linear(embedding_dim, mlp_dim, vb.pp("lin1"))?; - let lin2 = candle_nn::linear(mlp_dim, embedding_dim, vb.pp("lin2"))?; + let lin1 = linear(vb.pp("lin1"), embedding_dim, mlp_dim, true)?; + let lin2 = linear(vb.pp("lin2"), mlp_dim, embedding_dim, true)?; + let span = tracing::span!(tracing::Level::TRACE, "mlp-block"); Ok(Self { lin1, lin2, activation, + span, }) } } impl Module for MlpBlock { fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); xs.apply(&self.lin1)? .apply(&self.activation)? .apply(&self.lin2) } } +#[derive(Debug)] +pub struct Linear { + inner: candle_nn::Linear, + span: tracing::Span, +} + +impl Module for Linear { + fn forward(&self, x: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + self.inner.forward(x) + } +} + #[derive(Parser)] struct Args { #[arg(long)] @@ -109,10 +128,24 @@ struct Args { #[arg(long, default_value_t = 0.5)] point_y: f64, + + /// Enable tracing (generates a trace-timestamp.json file). + #[arg(long)] + tracing: bool, } pub fn main() -> anyhow::Result<()> { + use tracing_chrome::ChromeLayerBuilder; + use tracing_subscriber::prelude::*; + let args = Args::parse(); + let _guard = if args.tracing { + let (chrome_layer, guard) = ChromeLayerBuilder::new().build(); + tracing_subscriber::registry().with(chrome_layer).init(); + Some(guard) + } else { + None + }; let device = candle_examples::device(args.cpu)?; |