Diffstat (limited to 'candle-transformers/src/models/t5.rs')
-rw-r--r-- | candle-transformers/src/models/t5.rs | 181
1 file changed, 153 insertions(+), 28 deletions(-)
diff --git a/candle-transformers/src/models/t5.rs b/candle-transformers/src/models/t5.rs
index de7de496..c35dea0b 100644
--- a/candle-transformers/src/models/t5.rs
+++ b/candle-transformers/src/models/t5.rs
@@ -18,6 +18,21 @@ fn default_use_cache() -> bool {
     true
 }
 
+fn get_mask(size: usize, device: &Device) -> Result<Tensor> {
+    let mask: Vec<_> = (0..size)
+        .flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
+        .collect();
+    let result = Tensor::from_slice(&mask, (size, size), device)?;
+    Ok(result)
+}
+
+fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
+    let shape = mask.shape();
+    let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
+    let m = mask.where_cond(&on_true, on_false)?;
+    Ok(m)
+}
+
 #[derive(Debug, Clone, PartialEq, Deserialize)]
 pub struct Config {
     vocab_size: usize,
@@ -40,8 +55,8 @@ pub struct Config {
     is_encoder_decoder: bool,
     #[serde(default = "default_use_cache")]
     use_cache: bool,
-    pad_token_id: usize,
-    eos_token_id: usize,
+    pub pad_token_id: usize,
+    pub eos_token_id: usize,
 }
 
 impl Default for Config {
@@ -233,13 +248,13 @@ struct T5Attention {
 }
 
 impl T5Attention {
-    fn load(h: bool, vb: VarBuilder, cfg: &Config) -> Result<Self> {
+    fn load(has_relative_attention_bias: bool, vb: VarBuilder, cfg: &Config) -> Result<Self> {
         let inner_dim = cfg.num_heads * cfg.d_kv;
         let q = linear_no_bias(cfg.d_model, inner_dim, vb.pp("q"))?;
         let k = linear_no_bias(cfg.d_model, inner_dim, vb.pp("k"))?;
         let v = linear_no_bias(cfg.d_model, inner_dim, vb.pp("v"))?;
         let o = linear_no_bias(inner_dim, cfg.d_model, vb.pp("o"))?;
-        let relative_attention_bias = if h {
+        let relative_attention_bias = if has_relative_attention_bias {
             let emb = embedding(
                 cfg.relative_attention_num_buckets,
                 cfg.num_heads,
@@ -267,26 +282,46 @@ impl T5Attention {
         &self,
         xs: &Tensor,
         position_bias: Option<&Tensor>,
+        key_value_states: Option<&Tensor>,
+        mask: Option<&Tensor>,
     ) -> Result<(Tensor, Option<Tensor>)> {
-        // TODO: Apply the mask(s)?
+        // Performs Self-attention (if key_value_states is None) or attention
+        // over source sentence (provided by key_value_states).
         // TODO: kv caching.
-        let (b_sz, seq_len) = (xs.dim(0)?, xs.dim(1)?);
+        let kv_input = match key_value_states {
+            None => xs,
+            Some(key_value_states) => key_value_states,
+        };
+        let (b_sz, q_len) = (xs.dim(0)?, xs.dim(1)?);
+        let kv_len = kv_input.dim(1)?;
         let q = self.q.forward(xs)?;
-        let k = self.k.forward(xs)?;
-        let v = self.v.forward(xs)?;
+        let k = self.k.forward(kv_input)?;
+        let v = self.v.forward(kv_input)?;
         let q = q
-            .reshape((b_sz, seq_len, self.n_heads, self.d_kv))?
+            .reshape((b_sz, q_len, self.n_heads, self.d_kv))?
             .transpose(1, 2)?
             .contiguous()?;
         let k = k
-            .reshape((b_sz, seq_len, self.n_heads, self.d_kv))?
+            .reshape((b_sz, kv_len, self.n_heads, self.d_kv))?
             .transpose(1, 2)?
             .contiguous()?;
         let v = v
-            .reshape((b_sz, seq_len, self.n_heads, self.d_kv))?
+            .reshape((b_sz, kv_len, self.n_heads, self.d_kv))?
             .transpose(1, 2)?
             .contiguous()?;
+        // TODO: Use flash_attn.
         let scores = q.matmul(&k.t()?)?;
+        let scores = match mask {
+            None => scores,
+            Some(mask) => masked_fill(
+                &scores,
+                &mask
+                    .unsqueeze(0)?
+                    .unsqueeze(0)?
+                    .repeat((b_sz, self.n_heads))?,
+                f32::NEG_INFINITY,
+            )?,
+        };
 
         let (scores, position_bias) = match position_bias {
             Some(position_bias) => (
@@ -296,14 +331,12 @@ impl T5Attention {
             None => match &self.relative_attention_bias {
                 None => (scores, None),
                 Some(relative_attention_bias) => {
-                    let query_length = seq_len;
-                    let key_length = seq_len;
                     // This only handles the bidirectional case.
                     let num_buckets = self.relative_attention_num_buckets as u32 / 2;
                     let max_exact = num_buckets / 2;
-                    let relative_position = (0..query_length as u32)
+                    let relative_position = (0..q_len as u32)
                         .map(|i| {
-                            (0..key_length as u32)
+                            (0..kv_len as u32)
                                 .map(|j| {
                                     if i < j {
                                         if j - i < max_exact {
@@ -348,7 +381,7 @@ impl T5Attention {
         let attn_output = attn_weights.matmul(&v)?;
         let attn_output = attn_output
             .transpose(1, 2)?
-            .reshape((b_sz, seq_len, self.inner_dim))?;
+            .reshape((b_sz, q_len, self.inner_dim))?;
         let attn_output = self.o.forward(&attn_output)?;
         Ok((attn_output, position_bias))
     }
@@ -375,24 +408,49 @@ impl T5LayerSelfAttention {
         &self,
         xs: &Tensor,
         position_bias: Option<&Tensor>,
+        mask: Option<&Tensor>,
     ) -> Result<(Tensor, Option<Tensor>)> {
         let normed_xs = self.layer_norm.forward(xs)?;
-        let (ys, position_bias) = self.self_attention.forward(&normed_xs, position_bias)?;
+        let (ys, position_bias) =
+            self.self_attention
+                .forward(&normed_xs, position_bias, None, mask)?;
         let ys = (xs + ys)?;
         Ok((ys, position_bias))
     }
 }
 
 #[derive(Debug)]
-struct T5LayerCrossAttention {}
+struct T5LayerCrossAttention {
+    cross_attention: T5Attention,
+    layer_norm: T5LayerNorm,
+}
 
 impl T5LayerCrossAttention {
-    fn load(_vb: VarBuilder, _cfg: &Config) -> Result<Self> {
-        todo!()
+    fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
+        let cross_attention = T5Attention::load(false, vb.pp("EncDecAttention"), cfg)?;
+        let layer_norm =
+            T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?;
+        Ok(Self {
+            cross_attention,
+            layer_norm,
+        })
     }
 
-    fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
-        todo!()
+    fn forward(
+        &self,
+        hidden_states: &Tensor,
+        position_bias: Option<&Tensor>,
+        key_value_states: &Tensor,
+    ) -> Result<(Tensor, Option<Tensor>)> {
+        let normed_hidden_states = self.layer_norm.forward(hidden_states)?;
+        let (ys, position_bias) = self.cross_attention.forward(
+            &normed_hidden_states,
+            position_bias,
+            Some(key_value_states),
+            None,
+        )?;
+        let ys = (hidden_states + ys)?;
+        Ok((ys, position_bias))
     }
 }
 
@@ -425,11 +483,17 @@ impl T5Block {
         &self,
         xs: &Tensor,
         position_bias: Option<&Tensor>,
+        encoder_hidden_states: Option<&Tensor>,
     ) -> Result<(Tensor, Option<Tensor>)> {
-        let (mut xs, position_bias) = self.self_attn.forward(xs, position_bias)?;
+        // TODO: Cache masks
+        let mask = match self.cross_attn.is_some() {
+            true => Some(get_mask(xs.dim(1)?, xs.device())?),
+            false => None,
+        };
+        let (mut xs, position_bias) = self.self_attn.forward(xs, position_bias, mask.as_ref())?;
         // TODO: clamp for f16?
         if let Some(cross_attn) = &self.cross_attn {
-            xs = cross_attn.forward(&xs)?;
+            (xs, _) = cross_attn.forward(&xs, None, encoder_hidden_states.unwrap())?;
             // TODO: clamp for f16?
         }
         let xs = self.ff.forward(&xs)?;
@@ -462,13 +526,20 @@ impl T5Stack {
         })
     }
 
-    fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
+    fn forward(
+        &self,
+        input_ids: &Tensor,
+        encoder_hidden_states: Option<&Tensor>,
+    ) -> Result<Tensor> {
         let input_embeds = self.shared.as_ref().forward(input_ids)?;
         let mut hidden_states = input_embeds;
         let mut position_bias = None;
         for block in self.block.iter() {
-            (hidden_states, position_bias) =
-                block.forward(&hidden_states, position_bias.as_ref())?
+            (hidden_states, position_bias) = block.forward(
+                &hidden_states,
+                position_bias.as_ref(),
+                encoder_hidden_states,
+            )?
         }
         self.final_layer_norm.forward(&hidden_states)
     }
@@ -492,7 +563,61 @@ impl T5EncoderModel {
     }
 
     pub fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
-        self.encoder.forward(input_ids)
+        self.encoder.forward(input_ids, None)
+    }
+
+    pub fn device(&self) -> &Device {
+        &self.device
+    }
+}
+
+#[derive(Debug)]
+pub struct T5ForConditionalGeneration {
+    encoder: T5Stack,
+    decoder: T5Stack,
+    shared: Arc<Embedding>,
+    device: Device,
+}
+
+impl T5ForConditionalGeneration {
+    pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
+        assert!(cfg.is_encoder_decoder);
+        let shared = embedding(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?;
+        let shared = Arc::new(shared);
+
+        let mut encoder_cfg = cfg.clone();
+        encoder_cfg.is_decoder = false;
+        encoder_cfg.use_cache = false;
+        encoder_cfg.is_encoder_decoder = false;
+        let encoder = T5Stack::load(vb.pp("encoder"), &shared, &encoder_cfg)?;
+
+        let mut decoder_cfg = cfg.clone();
+        decoder_cfg.is_decoder = true;
+        decoder_cfg.is_encoder_decoder = false;
+        decoder_cfg.num_layers = cfg.num_decoder_layers.unwrap_or(cfg.num_layers);
+        let decoder = T5Stack::load(vb.pp("decoder"), &shared, &decoder_cfg)?;
+
+        Ok(Self {
+            encoder,
+            decoder,
+            shared,
+            device: vb.device().clone(),
+        })
+    }
+
+    pub fn forward(&self, input_ids: &Tensor, decoder_input_ids: &Tensor) -> Result<Tensor> {
+        let encoder_output = self.encoder.forward(input_ids, None)?;
+        let decoder_output = self
+            .decoder
+            .forward(decoder_input_ids, Some(&encoder_output))?;
+        let sequence_output = decoder_output
+            .narrow(1, decoder_output.dim(1)? - 1, 1)?
+            .squeeze(1)?;
+        // TODO: check cfg.tie_word_embeddings to load from model instead.
+        let lm_head_weights = self.shared.embeddings().t()?;
+        let output = sequence_output.matmul(&lm_head_weights)?;
+        // TODO: Rescale output before projecting on vocab? * (self.model_dim**-0.5)
+        Ok(output)
+    }
 
     pub fn device(&self) -> &Device {
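For reference, the `get_mask`/`masked_fill` pair added at the top of the file builds the causal mask used by the decoder's self-attention: every "future" key position (j > i) is marked and its attention score is overwritten with `f32::NEG_INFINITY`, so the softmax assigns it zero weight. The snippet below is a standalone sketch, not code from this commit (the helpers themselves are private); it mirrors the same construction for a length-4 sequence. The crate is imported as `candle` here, matching the `use candle::...` imports in t5.rs; outside the candle workspace the core crate is typically imported as `candle_core`.

use candle::{Device, Result, Tensor};

fn causal_mask_demo() -> Result<()> {
    let device = Device::Cpu;
    let size = 4usize;
    // Same construction as the private `get_mask` helper in the diff above.
    let mask: Vec<u8> = (0..size)
        .flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
        .collect();
    let mask = Tensor::from_slice(&mask, (size, size), &device)?;
    // Expected contents (1 marks positions that self-attention may not look at):
    // [[0, 1, 1, 1],
    //  [0, 0, 1, 1],
    //  [0, 0, 0, 1],
    //  [0, 0, 0, 0]]
    println!("{mask}");
    Ok(())
}

In `T5Attention::forward` this mask is unsqueezed twice and repeated over the batch and head dimensions before `masked_fill` writes the `-inf` values, which is why only the decoder blocks (the ones with a cross-attention layer) build it.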
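The new `T5ForConditionalGeneration::forward` pairs the encoder token ids with the decoder ids generated so far and returns the logits for the last decoder position only, projected through the shared embedding matrix. A minimal greedy-decoding loop on top of this API could look like the sketch below. It is an illustration, not code from this commit: it assumes the usual T5 convention of starting decoding from `pad_token_id` and stopping at `eos_token_id` (both made public by this change), assumes f32 weights, and, because there is no KV cache yet (see the TODO above), it re-runs the full encoder and decoder at every step.

use candle::{Result, Tensor};
use candle_transformers::models::t5::{Config, T5ForConditionalGeneration};

/// Greedy decoding sketch; `input_ids` holds the encoder token ids, shape (1, src_len).
fn greedy_decode(
    model: &T5ForConditionalGeneration,
    cfg: &Config,
    input_ids: &Tensor,
    max_len: usize,
) -> Result<Vec<u32>> {
    // T5 conventionally starts decoding from the pad token.
    let mut tokens: Vec<u32> = vec![cfg.pad_token_id as u32];
    for _ in 0..max_len {
        let decoder_input_ids = Tensor::new(tokens.as_slice(), model.device())?.unsqueeze(0)?;
        // Logits for the last decoder position, shape (1, vocab_size).
        // Without KV caching this re-encodes the source sequence every step.
        let logits = model.forward(input_ids, &decoder_input_ids)?;
        let logits = logits.squeeze(0)?.to_vec1::<f32>()?;
        let next_token = logits
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.total_cmp(b))
            .map(|(i, _)| i as u32)
            .unwrap_or(0);
        tokens.push(next_token);
        if next_token as usize == cfg.eos_token_id {
            break;
        }
    }
    Ok(tokens)
}

Sampling or beam search would replace the plain argmax over the returned logits, and the repeated encoder pass is the first cost a KV cache would remove.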