diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-11-01 19:21:36 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-11-01 18:21:36 +0000 |
commit | 1704f1b3aec92b07dd805411fa8065eab55e4186 (patch) | |
tree | 3ee9c59fc63e8a2cca9b134483173807b301ec47 /candle-transformers/src/models/whisper | |
parent | 693fad511ca4a52040f5c5f4aae1ee8c43d544ed (diff) | |
download | candle-1704f1b3aec92b07dd805411fa8065eab55e4186.tar.gz candle-1704f1b3aec92b07dd805411fa8065eab55e4186.tar.bz2 candle-1704f1b3aec92b07dd805411fa8065eab55e4186.zip |
Consolidate the with-tracing usage. (#1234)
Diffstat (limited to 'candle-transformers/src/models/whisper')
-rw-r--r-- | candle-transformers/src/models/whisper/model.rs | 28 |
1 files changed, 1 insertions, 27 deletions
diff --git a/candle-transformers/src/models/whisper/model.rs b/candle-transformers/src/models/whisper/model.rs index 6078944c..25454ba6 100644 --- a/candle-transformers/src/models/whisper/model.rs +++ b/candle-transformers/src/models/whisper/model.rs @@ -1,4 +1,5 @@ use super::Config; +use crate::models::with_tracing::{linear, linear_no_bias, Linear}; use candle::{Device, IndexOp, Result, Tensor, D}; use candle_nn::{Conv1d, Conv1dConfig, Embedding, LayerNorm, Module, VarBuilder}; @@ -6,33 +7,6 @@ fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Em let embeddings = vb.get((vocab_size, hidden_size), "weight")?; Ok(Embedding::new(embeddings, hidden_size)) } -// -// We wrap the `Linear` layer here to add some tracing so that it's easier to profile the resulting -// model. -#[derive(Debug, Clone)] -pub struct Linear { - inner: candle_nn::Linear, - span: tracing::Span, -} - -impl Linear { - fn forward(&self, x: &Tensor) -> Result<Tensor> { - let _enter = self.span.enter(); - self.inner.forward(x) - } -} - -fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> { - let span = tracing::span!(tracing::Level::TRACE, "linear"); - let inner = candle_nn::linear(size1, size2, vb)?; - Ok(Linear { inner, span }) -} - -fn linear_no_bias(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> { - let span = tracing::span!(tracing::Level::TRACE, "linear"); - let inner = candle_nn::linear_no_bias(size1, size2, vb)?; - Ok(Linear { inner, span }) -} fn conv1d( in_channels: usize, |