summaryrefslogtreecommitdiff
path: root/candle-examples/examples/segment-anything/main.rs
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-09-08 15:31:29 +0100
committerGitHub <noreply@github.com>2023-09-08 15:31:29 +0100
commit158ff3c609b22ed998dea5283738cc1ed13aa592 (patch)
treeab8f89575a1e9322898147132dc44a676b3e991a /candle-examples/examples/segment-anything/main.rs
parente5703d2f56ce24652e7ae85dc74484681e4dbcb9 (diff)
downloadcandle-158ff3c609b22ed998dea5283738cc1ed13aa592.tar.gz
candle-158ff3c609b22ed998dea5283738cc1ed13aa592.tar.bz2
candle-158ff3c609b22ed998dea5283738cc1ed13aa592.zip
Add tracing to segment-anything (#777)
* Tracing support for segment-anything. * More tracing. * Handle the empty slice case.
Diffstat (limited to 'candle-examples/examples/segment-anything/main.rs')
-rw-r--r--candle-examples/examples/segment-anything/main.rs47
1 files changed, 40 insertions, 7 deletions
diff --git a/candle-examples/examples/segment-anything/main.rs b/candle-examples/examples/segment-anything/main.rs
index c5095c0e..a749ba2a 100644
--- a/candle-examples/examples/segment-anything/main.rs
+++ b/candle-examples/examples/segment-anything/main.rs
@@ -14,15 +14,17 @@ pub mod model_sam;
pub mod model_transformer;
use candle::{DType, Result, Tensor};
-use candle_nn::{Linear, Module, VarBuilder};
+use candle_nn::{Module, VarBuilder};
use clap::Parser;
pub fn linear(vb: VarBuilder, in_dim: usize, out_dim: usize, bias: bool) -> Result<Linear> {
- if bias {
- candle_nn::linear(in_dim, out_dim, vb)
+ let inner = if bias {
+ candle_nn::linear(in_dim, out_dim, vb)?
} else {
- candle_nn::linear_no_bias(in_dim, out_dim, vb)
- }
+ candle_nn::linear_no_bias(in_dim, out_dim, vb)?
+ };
+ let span = tracing::span!(tracing::Level::TRACE, "linear");
+ Ok(Linear { inner, span })
}
#[derive(Debug)]
@@ -62,6 +64,7 @@ pub struct MlpBlock {
lin1: Linear,
lin2: Linear,
activation: candle_nn::Activation,
+ span: tracing::Span,
}
impl MlpBlock {
@@ -71,24 +74,40 @@ impl MlpBlock {
activation: candle_nn::Activation,
vb: VarBuilder,
) -> Result<Self> {
- let lin1 = candle_nn::linear(embedding_dim, mlp_dim, vb.pp("lin1"))?;
- let lin2 = candle_nn::linear(mlp_dim, embedding_dim, vb.pp("lin2"))?;
+ let lin1 = linear(vb.pp("lin1"), embedding_dim, mlp_dim, true)?;
+ let lin2 = linear(vb.pp("lin2"), mlp_dim, embedding_dim, true)?;
+ let span = tracing::span!(tracing::Level::TRACE, "mlp-block");
Ok(Self {
lin1,
lin2,
activation,
+ span,
})
}
}
impl Module for MlpBlock {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
xs.apply(&self.lin1)?
.apply(&self.activation)?
.apply(&self.lin2)
}
}
+#[derive(Debug)]
+pub struct Linear {
+ inner: candle_nn::Linear,
+ span: tracing::Span,
+}
+
+impl Module for Linear {
+ fn forward(&self, x: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ self.inner.forward(x)
+ }
+}
+
#[derive(Parser)]
struct Args {
#[arg(long)]
@@ -109,10 +128,24 @@ struct Args {
#[arg(long, default_value_t = 0.5)]
point_y: f64,
+
+ /// Enable tracing (generates a trace-timestamp.json file).
+ #[arg(long)]
+ tracing: bool,
}
pub fn main() -> anyhow::Result<()> {
+ use tracing_chrome::ChromeLayerBuilder;
+ use tracing_subscriber::prelude::*;
+
let args = Args::parse();
+ let _guard = if args.tracing {
+ let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
+ tracing_subscriber::registry().with(chrome_layer).init();
+ Some(guard)
+ } else {
+ None
+ };
let device = candle_examples::device(args.cpu)?;