summary | refs | log | tree | commit | diff
path: root/candle-transformers/src/models/segment_anything/mod.rs
diff options
context:
space:
mode:
author: Laurent Mazare <laurent.mazare@gmail.com> 2023-11-01 19:21:36 +0100
committer: GitHub <noreply@github.com> 2023-11-01 18:21:36 +0000
commit: 1704f1b3aec92b07dd805411fa8065eab55e4186 (patch)
tree: 3ee9c59fc63e8a2cca9b134483173807b301ec47 /candle-transformers/src/models/segment_anything/mod.rs
parent: 693fad511ca4a52040f5c5f4aae1ee8c43d544ed (diff)
download: candle-1704f1b3aec92b07dd805411fa8065eab55e4186.tar.gz
download: candle-1704f1b3aec92b07dd805411fa8065eab55e4186.tar.bz2
download: candle-1704f1b3aec92b07dd805411fa8065eab55e4186.zip
Consolidate the with-tracing usage. (#1234)
Diffstat (limited to 'candle-transformers/src/models/segment_anything/mod.rs')
-rw-r--r-- candle-transformers/src/models/segment_anything/mod.rs | 24
1 file changed, 5 insertions(+), 19 deletions(-)
diff --git a/candle-transformers/src/models/segment_anything/mod.rs b/candle-transformers/src/models/segment_anything/mod.rs
index c29db70a..c54493d2 100644
--- a/candle-transformers/src/models/segment_anything/mod.rs
+++ b/candle-transformers/src/models/segment_anything/mod.rs
@@ -1,3 +1,4 @@
+pub use crate::models::with_tracing::Linear;
use candle::{Result, Tensor};
use candle_nn::{Module, VarBuilder};
@@ -9,13 +10,11 @@ pub mod tiny_vit;
pub mod transformer;
pub fn linear(vb: VarBuilder, in_dim: usize, out_dim: usize, bias: bool) -> Result<Linear> {
- let inner = if bias {
- candle_nn::linear(in_dim, out_dim, vb)?
+ if bias {
+ crate::models::with_tracing::linear(in_dim, out_dim, vb)
} else {
- candle_nn::linear_no_bias(in_dim, out_dim, vb)?
- };
- let span = tracing::span!(tracing::Level::TRACE, "linear");
- Ok(Linear { inner, span })
+ crate::models::with_tracing::linear_no_bias(in_dim, out_dim, vb)
+ }
}
#[derive(Debug)]
@@ -85,16 +84,3 @@ impl Module for MlpBlock {
.apply(&self.lin2)
}
}
-
-#[derive(Debug)]
-pub struct Linear {
- inner: candle_nn::Linear,
- span: tracing::Span,
-}
-
-impl Module for Linear {
- fn forward(&self, x: &Tensor) -> Result<Tensor> {
- let _enter = self.span.enter();
- self.inner.forward(x)
- }
-}