Diffstat (limited to 'candle-transformers/src/models/quantized_metavoice.rs')
-rw-r--r--   candle-transformers/src/models/quantized_metavoice.rs | 18
1 file changed, 17 insertions, 1 deletion
diff --git a/candle-transformers/src/models/quantized_metavoice.rs b/candle-transformers/src/models/quantized_metavoice.rs
index 16545150..84c0388c 100644
--- a/candle-transformers/src/models/quantized_metavoice.rs
+++ b/candle-transformers/src/models/quantized_metavoice.rs
@@ -14,6 +14,7 @@ pub mod transformer {
         w1: Linear,
         w2: Linear,
         w3: Linear,
+        span: tracing::Span,
     }

     impl FeedForward {
@@ -22,12 +23,18 @@ pub mod transformer {
             let w1 = linear_b(cfg.dim, i_size, false, vb.pp("swiglu.w1"))?;
             let w2 = linear_b(i_size, cfg.dim, false, vb.pp("w2"))?;
             let w3 = linear_b(cfg.dim, i_size, false, vb.pp("swiglu.w3"))?;
-            Ok(Self { w1, w2, w3 })
+            Ok(Self {
+                w1,
+                w2,
+                w3,
+                span: tracing::span!(tracing::Level::TRACE, "feed-forward"),
+            })
         }
     }

     impl Module for FeedForward {
         fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+            let _enter = self.span.enter();
             let swiglu = (candle_nn::ops::silu(&xs.apply(&self.w1)?)? * xs.apply(&self.w3))?;
             swiglu.apply(&self.w2)
         }
@@ -43,6 +50,7 @@ pub mod transformer {
         head_dim: usize,
         n_head: usize,
         kv_cache: Option<(Tensor, Tensor)>,
+        span: tracing::Span,
     }

     impl Attention {
@@ -61,10 +69,12 @@ pub mod transformer {
                 head_dim,
                 n_head: cfg.n_head,
                 kv_cache: None,
+                span: tracing::span!(tracing::Level::TRACE, "attention"),
             })
         }

         fn forward(&mut self, xs: &Tensor, _pos: usize, mask: &Tensor) -> Result<Tensor> {
+            let _enter = self.span.enter();
             let (b_sz, seqlen, _) = xs.dims3()?;
             let qkv = xs.apply(&self.wqkv)?;

@@ -118,6 +128,7 @@ pub mod transformer {
         feed_forward: FeedForward,
         ffn_norm: RmsNorm,
         attention_norm: RmsNorm,
+        span: tracing::Span,
     }

     impl Block {
@@ -131,10 +142,12 @@ pub mod transformer {
                 feed_forward,
                 ffn_norm,
                 attention_norm,
+                span: tracing::span!(tracing::Level::TRACE, "block"),
             })
         }

         fn forward(&mut self, xs: &Tensor, pos: usize, mask: &Tensor) -> Result<Tensor> {
+            let _enter = self.span.enter();
             let hs = xs.apply(&self.attention_norm)?;
             let hs = (xs + self.attention.forward(&hs, pos, mask))?;
             &hs + hs.apply(&self.ffn_norm)?.apply(&self.feed_forward)
@@ -154,6 +167,7 @@ pub mod transformer {
         norm: RmsNorm,
         output: Linear,
         spk_cond_mask: Tensor,
+        span: tracing::Span,
     }

     impl Model {
@@ -189,6 +203,7 @@ pub mod transformer {
                 norm,
                 output,
                 spk_cond_mask,
+                span: tracing::span!(tracing::Level::TRACE, "qtransformer"),
             })
         }

@@ -199,6 +214,7 @@ pub mod transformer {
         }

         pub fn forward(&mut self, xs: &Tensor, spk_emb: &Tensor, pos: usize) -> Result<Tensor> {
+            let _enter = self.span.enter();
             let (_b_sz, seqlen) = xs.dims2()?;
             let mask: Vec<_> = (0..seqlen)
                 .flat_map(|i| (0..seqlen).map(move |j| if i < j { f32::NEG_INFINITY } else { 0. }))
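
These spans produce no output unless the program driving the model installs a tracing subscriber. Below is a minimal sketch, not part of this change, of how the TRACE-level spans introduced here ("feed-forward", "attention", "block", "qtransformer") could be collected into a Chrome trace file; it assumes the tracing, tracing-chrome, and tracing-subscriber crates are added as dependencies, and the model setup itself is elided.

// Sketch: install a subscriber so the spans above are written to a
// trace-*.json file that can be inspected in chrome://tracing or Perfetto.
// Assumes tracing, tracing-chrome and tracing-subscriber are in Cargo.toml.
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;

fn main() {
    // The returned guard flushes the trace file when dropped, so it must
    // stay alive for the whole run.
    let (chrome_layer, _guard) = ChromeLayerBuilder::new().build();
    tracing_subscriber::registry().with(chrome_layer).init();

    // ... load the quantized weights, build transformer::Model, and call
    // Model::forward here; each forward pass then records nested
    // qtransformer / block / attention / feed-forward spans.
}

Since each span is created once in the constructor and only entered in forward, the per-call cost when no subscriber is installed stays close to a no-op.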