summaryrefslogtreecommitdiff
path: root/candle-transformers/src/models/quantized_stable_lm.rs
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-10-10 08:09:25 +0200
committerGitHub <noreply@github.com>2023-10-10 08:09:25 +0200
commitbc3351bce4ce0ad24c69f872ffd51dc829fe88c8 (patch)
tree1ef7c95a493f8f3c122b3b7716105cb0dfae0053 /candle-transformers/src/models/quantized_stable_lm.rs
parentb34d7f0248d88d4cccbd0509a1865196b61de5d8 (diff)
downloadcandle-bc3351bce4ce0ad24c69f872ffd51dc829fe88c8.tar.gz
candle-bc3351bce4ce0ad24c69f872ffd51dc829fe88c8.tar.bz2
candle-bc3351bce4ce0ad24c69f872ffd51dc829fe88c8.zip
Tracing for StableLM and quantized StableLM. (#1068)
Diffstat (limited to 'candle-transformers/src/models/quantized_stable_lm.rs')
-rw-r--r-- candle-transformers/src/models/quantized_stable_lm.rs | 12
1 file changed, 12 insertions(+), 0 deletions(-)
diff --git a/candle-transformers/src/models/quantized_stable_lm.rs b/candle-transformers/src/models/quantized_stable_lm.rs
index 304e91ee..d117e4b3 100644
--- a/candle-transformers/src/models/quantized_stable_lm.rs
+++ b/candle-transformers/src/models/quantized_stable_lm.rs
@@ -14,6 +14,7 @@ struct MLP {
up_proj: Linear,
down_proj: Linear,
act_fn: Activation,
+ span: tracing::Span,
}
impl MLP {
@@ -28,12 +29,14 @@ impl MLP {
up_proj,
down_proj,
act_fn: cfg.hidden_act,
+ span: tracing::span!(tracing::Level::TRACE, "mlp"),
})
}
}
impl Module for MLP {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
let lhs = xs.apply(&self.gate_proj)?.apply(&self.act_fn)?;
let rhs = xs.apply(&self.up_proj)?;
(lhs * rhs)?.apply(&self.down_proj)
@@ -55,6 +58,7 @@ struct Attention {
kv_cache: Option<(Tensor, Tensor)>,
use_cache: bool,
rotary_ndims: usize,
+ span: tracing::Span,
}
impl Attention {
@@ -81,6 +85,7 @@ impl Attention {
kv_cache: None,
use_cache: cfg.use_cache,
rotary_ndims: cfg.rotary_ndims(),
+ span: tracing::span!(tracing::Level::TRACE, "attn"),
})
}
@@ -102,6 +107,7 @@ impl Attention {
attention_mask: Option<&Tensor>,
seqlen_offset: usize,
) -> Result<Tensor> {
+ let _enter = self.span.enter();
let (b_sz, q_len, _) = xs.dims3()?;
let query_states = self.q_proj.forward(xs)?;
@@ -168,6 +174,7 @@ struct DecoderLayer {
mlp: MLP,
input_layernorm: LayerNorm,
post_attention_layernorm: LayerNorm,
+ span: tracing::Span,
}
impl DecoderLayer {
@@ -185,6 +192,7 @@ impl DecoderLayer {
mlp,
input_layernorm,
post_attention_layernorm,
+ span: tracing::span!(tracing::Level::TRACE, "layer"),
})
}
@@ -194,6 +202,7 @@ impl DecoderLayer {
attention_mask: Option<&Tensor>,
seqlen_offset: usize,
) -> Result<Tensor> {
+ let _enter = self.span.enter();
let residual = xs;
let xs = self.input_layernorm.forward(xs)?;
let xs = self.self_attn.forward(&xs, attention_mask, seqlen_offset)?;
@@ -211,6 +220,7 @@ pub struct Model {
norm: LayerNorm,
lm_head: Linear,
device: Device,
+ span: tracing::Span,
}
impl Model {
@@ -233,6 +243,7 @@ impl Model {
norm,
lm_head,
device: vb.device().clone(),
+ span: tracing::span!(tracing::Level::TRACE, "model"),
})
}
@@ -258,6 +269,7 @@ impl Model {
}
pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
+ let _enter = self.span.enter();
let (b_size, seq_len) = input_ids.dims2()?;
let attention_mask = if seq_len <= 1 {
None