Diffstat (limited to 'candle-transformers/src/models/quantized_metavoice.rs')
-rw-r--r--  candle-transformers/src/models/quantized_metavoice.rs  18
1 file changed, 17 insertions(+), 1 deletion(-)
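The patch below threads a tracing::Span through each submodule of the quantized MetaVoice transformer (FeedForward, Attention, Block, Model): the span is built once in the constructor and entered at the top of the corresponding forward call. A minimal standalone sketch of the same pattern, with a hypothetical FeedForwardLike struct standing in for the real candle modules:

use tracing::{span, Level, Span};

// Sketch of the pattern used in this commit: the module owns a pre-built
// span and enters it at the start of `forward`, so time spent in the call
// shows up under a named scope in whatever tracing subscriber is installed.
struct FeedForwardLike {
    span: Span,
}

impl FeedForwardLike {
    fn new() -> Self {
        // Created once at construction time; entering it later is cheap.
        Self {
            span: span!(Level::TRACE, "feed-forward"),
        }
    }

    fn forward(&self, xs: &[f32]) -> Vec<f32> {
        // The guard keeps the span active until it is dropped at the end of the call.
        let _enter = self.span.enter();
        xs.iter().map(|x| x * 2.0).collect()
    }
}

fn main() {
    let ff = FeedForwardLike::new();
    let _ = ff.forward(&[1.0, 2.0, 3.0]);
}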
diff --git a/candle-transformers/src/models/quantized_metavoice.rs b/candle-transformers/src/models/quantized_metavoice.rs
index 16545150..84c0388c 100644
--- a/candle-transformers/src/models/quantized_metavoice.rs
+++ b/candle-transformers/src/models/quantized_metavoice.rs
@@ -14,6 +14,7 @@ pub mod transformer {
w1: Linear,
w2: Linear,
w3: Linear,
+ span: tracing::Span,
}
impl FeedForward {
@@ -22,12 +23,18 @@ pub mod transformer {
let w1 = linear_b(cfg.dim, i_size, false, vb.pp("swiglu.w1"))?;
let w2 = linear_b(i_size, cfg.dim, false, vb.pp("w2"))?;
let w3 = linear_b(cfg.dim, i_size, false, vb.pp("swiglu.w3"))?;
- Ok(Self { w1, w2, w3 })
+ Ok(Self {
+ w1,
+ w2,
+ w3,
+ span: tracing::span!(tracing::Level::TRACE, "feed-forward"),
+ })
}
}
impl Module for FeedForward {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
let swiglu = (candle_nn::ops::silu(&xs.apply(&self.w1)?)? * xs.apply(&self.w3))?;
swiglu.apply(&self.w2)
}
@@ -43,6 +50,7 @@ pub mod transformer {
head_dim: usize,
n_head: usize,
kv_cache: Option<(Tensor, Tensor)>,
+ span: tracing::Span,
}
impl Attention {
@@ -61,10 +69,12 @@ pub mod transformer {
head_dim,
n_head: cfg.n_head,
kv_cache: None,
+ span: tracing::span!(tracing::Level::TRACE, "attention"),
})
}
fn forward(&mut self, xs: &Tensor, _pos: usize, mask: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
let (b_sz, seqlen, _) = xs.dims3()?;
let qkv = xs.apply(&self.wqkv)?;
@@ -118,6 +128,7 @@ pub mod transformer {
feed_forward: FeedForward,
ffn_norm: RmsNorm,
attention_norm: RmsNorm,
+ span: tracing::Span,
}
impl Block {
@@ -131,10 +142,12 @@ pub mod transformer {
feed_forward,
ffn_norm,
attention_norm,
+ span: tracing::span!(tracing::Level::TRACE, "block"),
})
}
fn forward(&mut self, xs: &Tensor, pos: usize, mask: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
let hs = xs.apply(&self.attention_norm)?;
let hs = (xs + self.attention.forward(&hs, pos, mask))?;
&hs + hs.apply(&self.ffn_norm)?.apply(&self.feed_forward)
@@ -154,6 +167,7 @@ pub mod transformer {
norm: RmsNorm,
output: Linear,
spk_cond_mask: Tensor,
+ span: tracing::Span,
}
impl Model {
@@ -189,6 +203,7 @@ pub mod transformer {
norm,
output,
spk_cond_mask,
+ span: tracing::span!(tracing::Level::TRACE, "qtransformer"),
})
}
@@ -199,6 +214,7 @@ pub mod transformer {
}
pub fn forward(&mut self, xs: &Tensor, spk_emb: &Tensor, pos: usize) -> Result<Tensor> {
+ let _enter = self.span.enter();
let (_b_sz, seqlen) = xs.dims2()?;
let mask: Vec<_> = (0..seqlen)
.flat_map(|i| (0..seqlen).map(move |j| if i < j { f32::NEG_INFINITY } else { 0. }))
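These spans only produce output when a subscriber is installed by the caller. A sketch of how a binary might collect them into a Chrome trace, assuming the tracing-chrome and tracing-subscriber crates (this mirrors the setup candle example binaries typically use behind a --tracing flag; it is not part of this commit):

use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;

fn main() {
    // Build a layer that writes a Chrome-compatible trace file (trace-*.json
    // by default); the guard flushes the file when dropped at the end of main.
    let (chrome_layer, _guard) = ChromeLayerBuilder::new().build();
    tracing_subscriber::registry().with(chrome_layer).init();

    // ... run the model here. Every forward call that enters one of the spans
    // added in this commit ("feed-forward", "attention", "block",
    // "qtransformer") appears as a named slice when the trace file is opened
    // in chrome://tracing or Perfetto.
}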