summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-transformers/src/models/t5.rs21
1 files changed, 17 insertions, 4 deletions
diff --git a/candle-transformers/src/models/t5.rs b/candle-transformers/src/models/t5.rs
index ffa2764b..2b71fcda 100644
--- a/candle-transformers/src/models/t5.rs
+++ b/candle-transformers/src/models/t5.rs
@@ -321,6 +321,8 @@ struct T5Attention {
use_cache: bool,
kv_cache: Option<(Tensor, Tensor)>,
span: tracing::Span,
+ span_cache: tracing::Span,
+ span_mm: tracing::Span,
span_sm: tracing::Span,
}
@@ -360,6 +362,8 @@ impl T5Attention {
use_cache: cfg.use_cache && decoder,
kv_cache: None,
span: tracing::span!(tracing::Level::TRACE, "attention"),
+ span_cache: tracing::span!(tracing::Level::TRACE, "attention-cache"),
+ span_mm: tracing::span!(tracing::Level::TRACE, "attention-mm"),
span_sm: tracing::span!(tracing::Level::TRACE, "attention-sm"),
})
}
@@ -397,6 +401,7 @@ impl T5Attention {
.contiguous()?;
if self.use_cache {
+ let _enter = self.span_cache.enter();
if let Some((kv_cache_k, kv_cache_v)) = &self.kv_cache {
k = Tensor::cat(&[kv_cache_k, &k], 2)?.contiguous()?;
v = Tensor::cat(&[kv_cache_v, &v], 2)?.contiguous()?;
@@ -404,7 +409,10 @@ impl T5Attention {
self.kv_cache = Some((k.clone(), v.clone()));
};
// TODO: Use flash_attn.
- let scores = q.matmul(&k.t()?)?;
+ let scores = {
+ let _enter = self.span_mm.enter();
+ q.matmul(&k.t()?)?
+ };
let scores = match mask {
None => scores,
Some(mask) => masked_fill(
@@ -713,6 +721,7 @@ pub struct T5ForConditionalGeneration {
shared: Arc<Embedding>,
device: Device,
span_decode: tracing::Span,
+ span_decode_head: tracing::Span,
}
impl T5ForConditionalGeneration {
@@ -750,6 +759,7 @@ impl T5ForConditionalGeneration {
shared,
device: vb.device().clone(),
span_decode: tracing::span!(tracing::Level::TRACE, "decode"),
+ span_decode_head: tracing::span!(tracing::Level::TRACE, "decode-head"),
})
}
@@ -778,9 +788,12 @@ impl T5ForConditionalGeneration {
.narrow(1, decoder_output.dim(1)? - 1, 1)?
.squeeze(1)?)
* scaling_factor)?;
- let output = match self.lm_head {
- None => sequence_output.matmul(&self.shared.embeddings().t()?)?,
- Some(ref lm_head) => lm_head.forward(&sequence_output)?,
+ let output = {
+ let _enter = self.span_decode_head.enter();
+ match self.lm_head {
+ None => sequence_output.matmul(&self.shared.embeddings().t()?)?,
+ Some(ref lm_head) => lm_head.forward(&sequence_output)?,
+ }
};
// TODO: Rescale output before projecting on vocab? * (self.model_dim**-0.5)