summaryrefslogtreecommitdiff
path: root/candle-transformers/src/models/t5.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-transformers/src/models/t5.rs')
-rw-r--r--candle-transformers/src/models/t5.rs181
1 files changed, 153 insertions, 28 deletions
diff --git a/candle-transformers/src/models/t5.rs b/candle-transformers/src/models/t5.rs
index de7de496..c35dea0b 100644
--- a/candle-transformers/src/models/t5.rs
+++ b/candle-transformers/src/models/t5.rs
@@ -18,6 +18,21 @@ fn default_use_cache() -> bool {
true
}
+fn get_mask(size: usize, device: &Device) -> Result<Tensor> {
+ let mask: Vec<_> = (0..size)
+ .flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
+ .collect();
+ let result = Tensor::from_slice(&mask, (size, size), device)?;
+ Ok(result)
+}
+
+fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
+ let shape = mask.shape();
+ let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
+ let m = mask.where_cond(&on_true, on_false)?;
+ Ok(m)
+}
+
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct Config {
vocab_size: usize,
@@ -40,8 +55,8 @@ pub struct Config {
is_encoder_decoder: bool,
#[serde(default = "default_use_cache")]
use_cache: bool,
- pad_token_id: usize,
- eos_token_id: usize,
+ pub pad_token_id: usize,
+ pub eos_token_id: usize,
}
impl Default for Config {
@@ -233,13 +248,13 @@ struct T5Attention {
}
impl T5Attention {
- fn load(h: bool, vb: VarBuilder, cfg: &Config) -> Result<Self> {
+ fn load(has_relative_attention_bias: bool, vb: VarBuilder, cfg: &Config) -> Result<Self> {
let inner_dim = cfg.num_heads * cfg.d_kv;
let q = linear_no_bias(cfg.d_model, inner_dim, vb.pp("q"))?;
let k = linear_no_bias(cfg.d_model, inner_dim, vb.pp("k"))?;
let v = linear_no_bias(cfg.d_model, inner_dim, vb.pp("v"))?;
let o = linear_no_bias(inner_dim, cfg.d_model, vb.pp("o"))?;
- let relative_attention_bias = if h {
+ let relative_attention_bias = if has_relative_attention_bias {
let emb = embedding(
cfg.relative_attention_num_buckets,
cfg.num_heads,
@@ -267,26 +282,46 @@ impl T5Attention {
&self,
xs: &Tensor,
position_bias: Option<&Tensor>,
+ key_value_states: Option<&Tensor>,
+ mask: Option<&Tensor>,
) -> Result<(Tensor, Option<Tensor>)> {
- // TODO: Apply the mask(s)?
+ // Performs Self-attention (if key_value_states is None) or attention
+ // over source sentence (provided by key_value_states).
// TODO: kv caching.
- let (b_sz, seq_len) = (xs.dim(0)?, xs.dim(1)?);
+ let kv_input = match key_value_states {
+ None => xs,
+ Some(key_value_states) => key_value_states,
+ };
+ let (b_sz, q_len) = (xs.dim(0)?, xs.dim(1)?);
+ let kv_len = kv_input.dim(1)?;
let q = self.q.forward(xs)?;
- let k = self.k.forward(xs)?;
- let v = self.v.forward(xs)?;
+ let k = self.k.forward(kv_input)?;
+ let v = self.v.forward(kv_input)?;
let q = q
- .reshape((b_sz, seq_len, self.n_heads, self.d_kv))?
+ .reshape((b_sz, q_len, self.n_heads, self.d_kv))?
.transpose(1, 2)?
.contiguous()?;
let k = k
- .reshape((b_sz, seq_len, self.n_heads, self.d_kv))?
+ .reshape((b_sz, kv_len, self.n_heads, self.d_kv))?
.transpose(1, 2)?
.contiguous()?;
let v = v
- .reshape((b_sz, seq_len, self.n_heads, self.d_kv))?
+ .reshape((b_sz, kv_len, self.n_heads, self.d_kv))?
.transpose(1, 2)?
.contiguous()?;
+ // TODO: Use flash_attn.
let scores = q.matmul(&k.t()?)?;
+ let scores = match mask {
+ None => scores,
+ Some(mask) => masked_fill(
+ &scores,
+ &mask
+ .unsqueeze(0)?
+ .unsqueeze(0)?
+ .repeat((b_sz, self.n_heads))?,
+ f32::NEG_INFINITY,
+ )?,
+ };
let (scores, position_bias) = match position_bias {
Some(position_bias) => (
@@ -296,14 +331,12 @@ impl T5Attention {
None => match &self.relative_attention_bias {
None => (scores, None),
Some(relative_attention_bias) => {
- let query_length = seq_len;
- let key_length = seq_len;
// This only handles the bidirectional case.
let num_buckets = self.relative_attention_num_buckets as u32 / 2;
let max_exact = num_buckets / 2;
- let relative_position = (0..query_length as u32)
+ let relative_position = (0..q_len as u32)
.map(|i| {
- (0..key_length as u32)
+ (0..kv_len as u32)
.map(|j| {
if i < j {
if j - i < max_exact {
@@ -348,7 +381,7 @@ impl T5Attention {
let attn_output = attn_weights.matmul(&v)?;
let attn_output = attn_output
.transpose(1, 2)?
- .reshape((b_sz, seq_len, self.inner_dim))?;
+ .reshape((b_sz, q_len, self.inner_dim))?;
let attn_output = self.o.forward(&attn_output)?;
Ok((attn_output, position_bias))
}
@@ -375,24 +408,49 @@ impl T5LayerSelfAttention {
&self,
xs: &Tensor,
position_bias: Option<&Tensor>,
+ mask: Option<&Tensor>,
) -> Result<(Tensor, Option<Tensor>)> {
let normed_xs = self.layer_norm.forward(xs)?;
- let (ys, position_bias) = self.self_attention.forward(&normed_xs, position_bias)?;
+ let (ys, position_bias) =
+ self.self_attention
+ .forward(&normed_xs, position_bias, None, mask)?;
let ys = (xs + ys)?;
Ok((ys, position_bias))
}
}
#[derive(Debug)]
-struct T5LayerCrossAttention {}
+struct T5LayerCrossAttention {
+ cross_attention: T5Attention,
+ layer_norm: T5LayerNorm,
+}
impl T5LayerCrossAttention {
- fn load(_vb: VarBuilder, _cfg: &Config) -> Result<Self> {
- todo!()
+ fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
+ let cross_attention = T5Attention::load(false, vb.pp("EncDecAttention"), cfg)?;
+ let layer_norm =
+ T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?;
+ Ok(Self {
+ cross_attention,
+ layer_norm,
+ })
}
- fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
- todo!()
+ fn forward(
+ &self,
+ hidden_states: &Tensor,
+ position_bias: Option<&Tensor>,
+ key_value_states: &Tensor,
+ ) -> Result<(Tensor, Option<Tensor>)> {
+ let normed_hidden_states = self.layer_norm.forward(hidden_states)?;
+ let (ys, position_bias) = self.cross_attention.forward(
+ &normed_hidden_states,
+ position_bias,
+ Some(key_value_states),
+ None,
+ )?;
+ let ys = (hidden_states + ys)?;
+ Ok((ys, position_bias))
}
}
@@ -425,11 +483,17 @@ impl T5Block {
&self,
xs: &Tensor,
position_bias: Option<&Tensor>,
+ encoder_hidden_states: Option<&Tensor>,
) -> Result<(Tensor, Option<Tensor>)> {
- let (mut xs, position_bias) = self.self_attn.forward(xs, position_bias)?;
+ // TODO: Cache masks
+ let mask = match self.cross_attn.is_some() {
+ true => Some(get_mask(xs.dim(1)?, xs.device())?),
+ false => None,
+ };
+ let (mut xs, position_bias) = self.self_attn.forward(xs, position_bias, mask.as_ref())?;
// TODO: clamp for f16?
if let Some(cross_attn) = &self.cross_attn {
- xs = cross_attn.forward(&xs)?;
+ (xs, _) = cross_attn.forward(&xs, None, encoder_hidden_states.unwrap())?;
// TODO: clamp for f16?
}
let xs = self.ff.forward(&xs)?;
@@ -462,13 +526,20 @@ impl T5Stack {
})
}
- fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
+ fn forward(
+ &self,
+ input_ids: &Tensor,
+ encoder_hidden_states: Option<&Tensor>,
+ ) -> Result<Tensor> {
let input_embeds = self.shared.as_ref().forward(input_ids)?;
let mut hidden_states = input_embeds;
let mut position_bias = None;
for block in self.block.iter() {
- (hidden_states, position_bias) =
- block.forward(&hidden_states, position_bias.as_ref())?
+ (hidden_states, position_bias) = block.forward(
+ &hidden_states,
+ position_bias.as_ref(),
+ encoder_hidden_states,
+ )?
}
self.final_layer_norm.forward(&hidden_states)
}
@@ -492,7 +563,61 @@ impl T5EncoderModel {
}
pub fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
- self.encoder.forward(input_ids)
+ self.encoder.forward(input_ids, None)
+ }
+
+ pub fn device(&self) -> &Device {
+ &self.device
+ }
+}
+
+#[derive(Debug)]
+pub struct T5ForConditionalGeneration {
+ encoder: T5Stack,
+ decoder: T5Stack,
+ shared: Arc<Embedding>,
+ device: Device,
+}
+
+impl T5ForConditionalGeneration {
+ pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
+ assert!(cfg.is_encoder_decoder);
+ let shared = embedding(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?;
+ let shared = Arc::new(shared);
+
+ let mut encoder_cfg = cfg.clone();
+ encoder_cfg.is_decoder = false;
+ encoder_cfg.use_cache = false;
+ encoder_cfg.is_encoder_decoder = false;
+ let encoder = T5Stack::load(vb.pp("encoder"), &shared, &encoder_cfg)?;
+
+ let mut decoder_cfg = cfg.clone();
+ decoder_cfg.is_decoder = true;
+ decoder_cfg.is_encoder_decoder = false;
+ decoder_cfg.num_layers = cfg.num_decoder_layers.unwrap_or(cfg.num_layers);
+ let decoder = T5Stack::load(vb.pp("decoder"), &shared, &decoder_cfg)?;
+
+ Ok(Self {
+ encoder,
+ decoder,
+ shared,
+ device: vb.device().clone(),
+ })
+ }
+
+ pub fn forward(&self, input_ids: &Tensor, decoder_input_ids: &Tensor) -> Result<Tensor> {
+ let encoder_output = self.encoder.forward(input_ids, None)?;
+ let decoder_output = self
+ .decoder
+ .forward(decoder_input_ids, Some(&encoder_output))?;
+ let sequence_output = decoder_output
+ .narrow(1, decoder_output.dim(1)? - 1, 1)?
+ .squeeze(1)?;
+ // TODO: check cfg.tie_word_embeddings to load from model instead.
+ let lm_head_weights = self.shared.embeddings().t()?;
+ let output = sequence_output.matmul(&lm_head_weights)?;
+ // TODO: Rescale output before projecting on vocab? * (self.model_dim**-0.5)
+ Ok(output)
}
pub fn device(&self) -> &Device {