diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-10-10 08:09:25 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-10-10 08:09:25 +0200 |
commit | bc3351bce4ce0ad24c69f872ffd51dc829fe88c8 (patch) | |
tree | 1ef7c95a493f8f3c122b3b7716105cb0dfae0053 /candle-transformers/src/models/quantized_stable_lm.rs | |
parent | b34d7f0248d88d4cccbd0509a1865196b61de5d8 (diff) | |
download | candle-bc3351bce4ce0ad24c69f872ffd51dc829fe88c8.tar.gz candle-bc3351bce4ce0ad24c69f872ffd51dc829fe88c8.tar.bz2 candle-bc3351bce4ce0ad24c69f872ffd51dc829fe88c8.zip |
Tracing for StableLM and quantized StableLM. (#1068)
Diffstat (limited to 'candle-transformers/src/models/quantized_stable_lm.rs')
-rw-r--r-- | candle-transformers/src/models/quantized_stable_lm.rs | 12 |
1 files changed, 12 insertions, 0 deletions
diff --git a/candle-transformers/src/models/quantized_stable_lm.rs b/candle-transformers/src/models/quantized_stable_lm.rs index 304e91ee..d117e4b3 100644 --- a/candle-transformers/src/models/quantized_stable_lm.rs +++ b/candle-transformers/src/models/quantized_stable_lm.rs @@ -14,6 +14,7 @@ struct MLP { up_proj: Linear, down_proj: Linear, act_fn: Activation, + span: tracing::Span, } impl MLP { @@ -28,12 +29,14 @@ impl MLP { up_proj, down_proj, act_fn: cfg.hidden_act, + span: tracing::span!(tracing::Level::TRACE, "mlp"), }) } } impl Module for MLP { fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); let lhs = xs.apply(&self.gate_proj)?.apply(&self.act_fn)?; let rhs = xs.apply(&self.up_proj)?; (lhs * rhs)?.apply(&self.down_proj) @@ -55,6 +58,7 @@ struct Attention { kv_cache: Option<(Tensor, Tensor)>, use_cache: bool, rotary_ndims: usize, + span: tracing::Span, } impl Attention { @@ -81,6 +85,7 @@ impl Attention { kv_cache: None, use_cache: cfg.use_cache, rotary_ndims: cfg.rotary_ndims(), + span: tracing::span!(tracing::Level::TRACE, "attn"), }) } @@ -102,6 +107,7 @@ impl Attention { attention_mask: Option<&Tensor>, seqlen_offset: usize, ) -> Result<Tensor> { + let _enter = self.span.enter(); let (b_sz, q_len, _) = xs.dims3()?; let query_states = self.q_proj.forward(xs)?; @@ -168,6 +174,7 @@ struct DecoderLayer { mlp: MLP, input_layernorm: LayerNorm, post_attention_layernorm: LayerNorm, + span: tracing::Span, } impl DecoderLayer { @@ -185,6 +192,7 @@ impl DecoderLayer { mlp, input_layernorm, post_attention_layernorm, + span: tracing::span!(tracing::Level::TRACE, "layer"), }) } @@ -194,6 +202,7 @@ impl DecoderLayer { attention_mask: Option<&Tensor>, seqlen_offset: usize, ) -> Result<Tensor> { + let _enter = self.span.enter(); let residual = xs; let xs = self.input_layernorm.forward(xs)?; let xs = self.self_attn.forward(&xs, attention_mask, seqlen_offset)?; @@ -211,6 +220,7 @@ pub struct Model { norm: LayerNorm, lm_head: Linear, device: Device, + span: tracing::Span, } impl Model { @@ -233,6 +243,7 @@ impl Model { norm, lm_head, device: vb.device().clone(), + span: tracing::span!(tracing::Level::TRACE, "model"), }) } @@ -258,6 +269,7 @@ impl Model { } pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> { + let _enter = self.span.enter(); let (b_size, seq_len) = input_ids.dims2()?; let attention_mask = if seq_len <= 1 { None |