author     Laurent Mazare <laurent.mazare@gmail.com>  2023-11-01 19:21:36 +0100
committer  GitHub <noreply@github.com>                2023-11-01 18:21:36 +0000
commit     1704f1b3aec92b07dd805411fa8065eab55e4186 (patch)
tree       3ee9c59fc63e8a2cca9b134483173807b301ec47 /candle-transformers/src/models/whisper
parent     693fad511ca4a52040f5c5f4aae1ee8c43d544ed (diff)
Consolidate the with-tracing usage. (#1234)
Diffstat (limited to 'candle-transformers/src/models/whisper')
-rw-r--r--  candle-transformers/src/models/whisper/model.rs | 28
1 file changed, 1 insertion(+), 27 deletions(-)
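For context, the diff below removes the Whisper-local tracing wrapper and instead imports the shared crate::models::with_tracing module. A minimal sketch of what that shared module might provide, assuming it mirrors the wrapper being deleted here (the module path and signatures are taken from the new import line and the removed code, not verified against the shared file):

// Hypothetical sketch of crate::models::with_tracing, assuming it keeps the
// same shape as the wrapper deleted from model.rs: a Linear layer that enters
// a tracing span on every forward pass so profiles can attribute time to
// individual layers.
use candle::{Result, Tensor};
use candle_nn::{Module, VarBuilder};

#[derive(Debug, Clone)]
pub struct Linear {
    inner: candle_nn::Linear,
    span: tracing::Span,
}

impl Linear {
    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
        // Entering the span makes each linear layer show up in trace output.
        let _enter = self.span.enter();
        self.inner.forward(x)
    }
}

pub fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
    let span = tracing::span!(tracing::Level::TRACE, "linear");
    let inner = candle_nn::linear(size1, size2, vb)?;
    Ok(Linear { inner, span })
}

pub fn linear_no_bias(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
    let span = tracing::span!(tracing::Level::TRACE, "linear");
    let inner = candle_nn::linear_no_bias(size1, size2, vb)?;
    Ok(Linear { inner, span })
}

Since the diff only adds the import and deletes the local definitions, the existing linear(...) and linear_no_bias(...) call sites in model.rs stay unchanged and now resolve through the new use line.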
diff --git a/candle-transformers/src/models/whisper/model.rs b/candle-transformers/src/models/whisper/model.rs
index 6078944c..25454ba6 100644
--- a/candle-transformers/src/models/whisper/model.rs
+++ b/candle-transformers/src/models/whisper/model.rs
@@ -1,4 +1,5 @@
use super::Config;
+use crate::models::with_tracing::{linear, linear_no_bias, Linear};
use candle::{Device, IndexOp, Result, Tensor, D};
use candle_nn::{Conv1d, Conv1dConfig, Embedding, LayerNorm, Module, VarBuilder};
@@ -6,33 +7,6 @@ fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Em
let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
Ok(Embedding::new(embeddings, hidden_size))
}
-//
-// We wrap the `Linear` layer here to add some tracing so that it's easier to profile the resulting
-// model.
-#[derive(Debug, Clone)]
-pub struct Linear {
- inner: candle_nn::Linear,
- span: tracing::Span,
-}
-
-impl Linear {
- fn forward(&self, x: &Tensor) -> Result<Tensor> {
- let _enter = self.span.enter();
- self.inner.forward(x)
- }
-}
-
-fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
- let span = tracing::span!(tracing::Level::TRACE, "linear");
- let inner = candle_nn::linear(size1, size2, vb)?;
- Ok(Linear { inner, span })
-}
-
-fn linear_no_bias(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
- let span = tracing::span!(tracing::Level::TRACE, "linear");
- let inner = candle_nn::linear_no_bias(size1, size2, vb)?;
- Ok(Linear { inner, span })
-}
fn conv1d(
in_channels: usize,