summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-transformers/src/models/t5.rs40
1 files changed, 35 insertions, 5 deletions
diff --git a/candle-transformers/src/models/t5.rs b/candle-transformers/src/models/t5.rs
index efb2819b..ffa2764b 100644
--- a/candle-transformers/src/models/t5.rs
+++ b/candle-transformers/src/models/t5.rs
@@ -2,11 +2,36 @@
// https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
use candle::{DType, Device, Module, Result, Tensor, D};
-use candle_nn::{embedding, Activation, Embedding, VarBuilder};
+use candle_nn::{Activation, VarBuilder};
use serde::Deserialize;
use std::sync::Arc;
#[derive(Debug)]
+struct Embedding {
+ inner: candle_nn::Embedding,
+ span: tracing::Span,
+}
+
+impl Embedding {
+ fn new(d1: usize, d2: usize, vb: VarBuilder) -> Result<Self> {
+ let inner = candle_nn::embedding(d1, d2, vb)?;
+ let span = tracing::span!(tracing::Level::TRACE, "embedding");
+ Ok(Self { inner, span })
+ }
+
+ fn embeddings(&self) -> &Tensor {
+ self.inner.embeddings()
+ }
+}
+
+impl Module for Embedding {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ self.inner.forward(xs)
+ }
+}
+
+#[derive(Debug)]
struct Linear {
inner: candle_nn::Linear,
span: tracing::Span,
@@ -296,6 +321,7 @@ struct T5Attention {
use_cache: bool,
kv_cache: Option<(Tensor, Tensor)>,
span: tracing::Span,
+ span_sm: tracing::Span,
}
impl T5Attention {
@@ -311,7 +337,7 @@ impl T5Attention {
let v = Linear::new(cfg.d_model, inner_dim, vb.pp("v"))?;
let o = Linear::new(inner_dim, cfg.d_model, vb.pp("o"))?;
let relative_attention_bias = if has_relative_attention_bias {
- let emb = embedding(
+ let emb = Embedding::new(
cfg.relative_attention_num_buckets,
cfg.num_heads,
vb.pp("relative_attention_bias"),
@@ -334,6 +360,7 @@ impl T5Attention {
use_cache: cfg.use_cache && decoder,
kv_cache: None,
span: tracing::span!(tracing::Level::TRACE, "attention"),
+ span_sm: tracing::span!(tracing::Level::TRACE, "attention-sm"),
})
}
@@ -449,7 +476,10 @@ impl T5Attention {
},
};
- let attn_weights = candle_nn::ops::softmax(&scores, D::Minus1)?;
+ let attn_weights = {
+ let _enter = self.span_sm.enter();
+ candle_nn::ops::softmax(&scores, D::Minus1)?
+ };
let attn_output = attn_weights.matmul(&v)?;
let attn_output = attn_output
.transpose(1, 2)?
@@ -653,7 +683,7 @@ pub struct T5EncoderModel {
impl T5EncoderModel {
pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
- let shared = embedding(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?;
+ let shared = Embedding::new(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?;
let shared = Arc::new(shared);
let encoder = T5Stack::load(false, vb.pp("encoder"), &shared, cfg)?;
Ok(Self {
@@ -689,7 +719,7 @@ impl T5ForConditionalGeneration {
pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
assert!(cfg.is_encoder_decoder);
let d_model = cfg.d_model;
- let shared = embedding(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?;
+ let shared = Embedding::new(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?;
let shared = Arc::new(shared);
let mut encoder_cfg = cfg.clone();