Diffstat (limited to 'candle-transformers/src/models/with_tracing.rs')
-rw-r--r--  candle-transformers/src/models/with_tracing.rs  78
1 file changed, 78 insertions(+), 0 deletions(-)
diff --git a/candle-transformers/src/models/with_tracing.rs b/candle-transformers/src/models/with_tracing.rs
new file mode 100644
index 00000000..0a2d65b9
--- /dev/null
+++ b/candle-transformers/src/models/with_tracing.rs
@@ -0,0 +1,78 @@
+use candle::{Module, Result, Tensor};
+use candle_nn::VarBuilder;
+
+#[derive(Debug)]
+pub struct Embedding {
+ inner: candle_nn::Embedding,
+ span: tracing::Span,
+}
+
+impl Embedding {
+ pub fn new(d1: usize, d2: usize, vb: VarBuilder) -> Result<Self> {
+ let inner = candle_nn::embedding(d1, d2, vb)?;
+ let span = tracing::span!(tracing::Level::TRACE, "embedding");
+ Ok(Self { inner, span })
+ }
+
+ pub fn embeddings(&self) -> &Tensor {
+ self.inner.embeddings()
+ }
+}
+
+impl Module for Embedding {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ self.inner.forward(xs)
+ }
+}
+
+#[derive(Debug)]
+pub struct Linear {
+ inner: candle_nn::Linear,
+ span: tracing::Span,
+}
+
+pub fn linear(d1: usize, d2: usize, vb: VarBuilder) -> Result<Linear> {
+ let inner = candle_nn::linear(d1, d2, vb)?;
+ let span = tracing::span!(tracing::Level::TRACE, "linear");
+ Ok(Linear { inner, span })
+}
+
+pub fn linear_no_bias(d1: usize, d2: usize, vb: VarBuilder) -> Result<Linear> {
+ let inner = candle_nn::linear_no_bias(d1, d2, vb)?;
+ let span = tracing::span!(tracing::Level::TRACE, "linear");
+ Ok(Linear { inner, span })
+}
+
+impl Module for Linear {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ self.inner.forward(xs)
+ }
+}
+
+// Wrap the conv2d op to provide some tracing.
+#[derive(Debug)]
+pub struct Conv2d {
+ inner: candle_nn::Conv2d,
+ span: tracing::Span,
+}
+
+impl Conv2d {
+ pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ self.inner.forward(x)
+ }
+}
+
+pub fn conv2d(
+ in_channels: usize,
+ out_channels: usize,
+ kernel_size: usize,
+ cfg: candle_nn::Conv2dConfig,
+ vs: candle_nn::VarBuilder,
+) -> Result<Conv2d> {
+ let span = tracing::span!(tracing::Level::TRACE, "conv2d");
+ let inner = candle_nn::conv2d(in_channels, out_channels, kernel_size, cfg, vs)?;
+ Ok(Conv2d { inner, span })
+}
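
A minimal usage sketch of the wrappers above, assuming they are exported as candle_transformers::models::with_tracing, that weights come from a zero-initialized VarBuilder on the CPU device, and that the tracing-subscriber crate is used to collect the TRACE-level spans recorded on each forward pass:

use candle::{DType, Device, Module, Result, Tensor};
use candle_nn::VarBuilder;
// Assumed export path for the wrappers added in this file.
use candle_transformers::models::with_tracing::{linear, Embedding};

fn main() -> Result<()> {
    // Any tracing subscriber works; this one prints span enter events.
    tracing_subscriber::fmt()
        .with_max_level(tracing::Level::TRACE)
        .with_span_events(tracing_subscriber::fmt::format::FmtSpan::ENTER)
        .init();

    let device = Device::Cpu;
    // Zero-initialized weights, just to have something to run through.
    let vb = VarBuilder::zeros(DType::F32, &device);

    // Same construction API as candle_nn, but every forward call below is
    // recorded under its own span ("embedding", "linear").
    let emb = Embedding::new(16, 8, vb.pp("emb"))?;
    let lin = linear(8, 4, vb.pp("lin"))?;

    let ids = Tensor::new(&[0u32, 1, 2], &device)?;
    let ys = lin.forward(&emb.forward(&ids)?)?;
    println!("{:?}", ys.shape());
    Ok(())
}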