diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-09-17 09:00:45 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-09-17 08:00:45 +0100 |
commit | 1a276b5da79a4bb2305dde7368b800d165599819 (patch) | |
tree | 5270c7d9c0b6e345cfd65c3d74690c3488d65aa5 /candle-transformers/src/models/t5.rs | |
parent | 8658df348527cabcd722bfe2e9e48aba3c7f8e96 (diff) | |
download | candle-1a276b5da79a4bb2305dde7368b800d165599819.tar.gz candle-1a276b5da79a4bb2305dde7368b800d165599819.tar.bz2 candle-1a276b5da79a4bb2305dde7368b800d165599819.zip |
Add a KV cache to T5. (#873)
* Add a KV cache to T5.
* Suggest using release mode.
* Use the kv cache in decoding.
* Add a comment.
Diffstat (limited to 'candle-transformers/src/models/t5.rs')
-rw-r--r-- | candle-transformers/src/models/t5.rs | 85 |
1 file changed, 58 insertions, 27 deletions
diff --git a/candle-transformers/src/models/t5.rs b/candle-transformers/src/models/t5.rs index c35dea0b..8b621f64 100644 --- a/candle-transformers/src/models/t5.rs +++ b/candle-transformers/src/models/t5.rs @@ -54,7 +54,7 @@ pub struct Config { is_decoder: bool, is_encoder_decoder: bool, #[serde(default = "default_use_cache")] - use_cache: bool, + pub use_cache: bool, pub pad_token_id: usize, pub eos_token_id: usize, } @@ -245,10 +245,17 @@ struct T5Attention { relative_attention_num_buckets: usize, relative_attention_max_distance: usize, inner_dim: usize, + use_cache: bool, + kv_cache: Option<(Tensor, Tensor)>, } impl T5Attention { - fn load(has_relative_attention_bias: bool, vb: VarBuilder, cfg: &Config) -> Result<Self> { + fn load( + has_relative_attention_bias: bool, + decoder: bool, + vb: VarBuilder, + cfg: &Config, + ) -> Result<Self> { let inner_dim = cfg.num_heads * cfg.d_kv; let q = linear_no_bias(cfg.d_model, inner_dim, vb.pp("q"))?; let k = linear_no_bias(cfg.d_model, inner_dim, vb.pp("k"))?; @@ -275,11 +282,13 @@ impl T5Attention { relative_attention_num_buckets: cfg.relative_attention_num_buckets, relative_attention_max_distance: cfg.relative_attention_max_distance, inner_dim, + use_cache: cfg.use_cache && decoder, + kv_cache: None, }) } fn forward( - &self, + &mut self, xs: &Tensor, position_bias: Option<&Tensor>, key_value_states: Option<&Tensor>, @@ -287,7 +296,6 @@ impl T5Attention { ) -> Result<(Tensor, Option<Tensor>)> { // Performs Self-attention (if key_value_states is None) or attention // over source sentence (provided by key_value_states). - // TODO: kv caching. let kv_input = match key_value_states { None => xs, Some(key_value_states) => key_value_states, @@ -301,14 +309,22 @@ impl T5Attention { .reshape((b_sz, q_len, self.n_heads, self.d_kv))? .transpose(1, 2)? .contiguous()?; - let k = k + let mut k = k .reshape((b_sz, kv_len, self.n_heads, self.d_kv))? .transpose(1, 2)? 
.contiguous()?; - let v = v + let mut v = v .reshape((b_sz, kv_len, self.n_heads, self.d_kv))? .transpose(1, 2)? .contiguous()?; + + if self.use_cache { + if let Some((kv_cache_k, kv_cache_v)) = &self.kv_cache { + k = Tensor::cat(&[kv_cache_k, &k], 2)?.contiguous()?; + v = Tensor::cat(&[kv_cache_v, &v], 2)?.contiguous()?; + }; + self.kv_cache = Some((k.clone(), v.clone())); + }; // TODO: Use flash_attn. let scores = q.matmul(&k.t()?)?; let scores = match mask { @@ -394,8 +410,8 @@ struct T5LayerSelfAttention { } impl T5LayerSelfAttention { - fn load(h: bool, vb: VarBuilder, cfg: &Config) -> Result<Self> { - let self_attention = T5Attention::load(h, vb.pp("SelfAttention"), cfg)?; + fn load(h: bool, d: bool, vb: VarBuilder, cfg: &Config) -> Result<Self> { + let self_attention = T5Attention::load(h, d, vb.pp("SelfAttention"), cfg)?; let layer_norm = T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?; Ok(Self { @@ -405,7 +421,7 @@ impl T5LayerSelfAttention { } fn forward( - &self, + &mut self, xs: &Tensor, position_bias: Option<&Tensor>, mask: Option<&Tensor>, @@ -426,8 +442,8 @@ struct T5LayerCrossAttention { } impl T5LayerCrossAttention { - fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> { - let cross_attention = T5Attention::load(false, vb.pp("EncDecAttention"), cfg)?; + fn load(decoder: bool, vb: VarBuilder, cfg: &Config) -> Result<Self> { + let cross_attention = T5Attention::load(false, decoder, vb.pp("EncDecAttention"), cfg)?; let layer_norm = T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?; Ok(Self { @@ -437,7 +453,7 @@ impl T5LayerCrossAttention { } fn forward( - &self, + &mut self, hidden_states: &Tensor, position_bias: Option<&Tensor>, key_value_states: &Tensor, @@ -462,11 +478,17 @@ struct T5Block { } impl T5Block { - fn load(has_relative_attention_bias: bool, vb: VarBuilder, cfg: &Config) -> Result<Self> { + fn load( + has_relative_attention_bias: bool, + decoder: bool, + vb: VarBuilder, + cfg: 
&Config, + ) -> Result<Self> { let vb = vb.pp("layer"); - let self_attn = T5LayerSelfAttention::load(has_relative_attention_bias, vb.pp("0"), cfg)?; + let self_attn = + T5LayerSelfAttention::load(has_relative_attention_bias, decoder, vb.pp("0"), cfg)?; let cross_attn = if cfg.is_decoder { - Some(T5LayerCrossAttention::load(vb.pp("1"), cfg)?) + Some(T5LayerCrossAttention::load(decoder, vb.pp("1"), cfg)?) } else { None }; @@ -480,19 +502,28 @@ impl T5Block { } fn forward( - &self, + &mut self, xs: &Tensor, position_bias: Option<&Tensor>, encoder_hidden_states: Option<&Tensor>, ) -> Result<(Tensor, Option<Tensor>)> { // TODO: Cache masks let mask = match self.cross_attn.is_some() { - true => Some(get_mask(xs.dim(1)?, xs.device())?), + true => { + let mask_len = xs.dim(1)?; + // If the input seq length is 1, no need for a mask, this is also helpful to avoid shape + // issues when using the KV cache in the decoder. + if mask_len <= 1 { + None + } else { + Some(get_mask(mask_len, xs.device())?) + } + } false => None, }; let (mut xs, position_bias) = self.self_attn.forward(xs, position_bias, mask.as_ref())?; // TODO: clamp for f16? - if let Some(cross_attn) = &self.cross_attn { + if let Some(cross_attn) = &mut self.cross_attn { (xs, _) = cross_attn.forward(&xs, None, encoder_hidden_states.unwrap())?; // TODO: clamp for f16? 
} @@ -510,9 +541,9 @@ struct T5Stack { } impl T5Stack { - fn load(vb: VarBuilder, shared: &Arc<Embedding>, cfg: &Config) -> Result<Self> { + fn load(decoder: bool, vb: VarBuilder, shared: &Arc<Embedding>, cfg: &Config) -> Result<Self> { let block = (0..cfg.num_layers) - .map(|i| T5Block::load(i == 0, vb.pp(&format!("block.{i}")), cfg)) + .map(|i| T5Block::load(i == 0, decoder, vb.pp(&format!("block.{i}")), cfg)) .collect::<Result<Vec<_>>>()?; let final_layer_norm = T5LayerNorm::load( cfg.d_model, @@ -527,14 +558,14 @@ impl T5Stack { } fn forward( - &self, + &mut self, input_ids: &Tensor, encoder_hidden_states: Option<&Tensor>, ) -> Result<Tensor> { let input_embeds = self.shared.as_ref().forward(input_ids)?; let mut hidden_states = input_embeds; let mut position_bias = None; - for block in self.block.iter() { + for block in self.block.iter_mut() { (hidden_states, position_bias) = block.forward( &hidden_states, position_bias.as_ref(), @@ -555,14 +586,14 @@ impl T5EncoderModel { pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> { let shared = embedding(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?; let shared = Arc::new(shared); - let encoder = T5Stack::load(vb.pp("encoder"), &shared, cfg)?; + let encoder = T5Stack::load(false, vb.pp("encoder"), &shared, cfg)?; Ok(Self { encoder, device: vb.device().clone(), }) } - pub fn forward(&self, input_ids: &Tensor) -> Result<Tensor> { + pub fn forward(&mut self, input_ids: &Tensor) -> Result<Tensor> { self.encoder.forward(input_ids, None) } @@ -589,13 +620,13 @@ impl T5ForConditionalGeneration { encoder_cfg.is_decoder = false; encoder_cfg.use_cache = false; encoder_cfg.is_encoder_decoder = false; - let encoder = T5Stack::load(vb.pp("encoder"), &shared, &encoder_cfg)?; + let encoder = T5Stack::load(false, vb.pp("encoder"), &shared, &encoder_cfg)?; let mut decoder_cfg = cfg.clone(); decoder_cfg.is_decoder = true; decoder_cfg.is_encoder_decoder = false; decoder_cfg.num_layers = 
cfg.num_decoder_layers.unwrap_or(cfg.num_layers); - let decoder = T5Stack::load(vb.pp("decoder"), &shared, &decoder_cfg)?; + let decoder = T5Stack::load(true, vb.pp("decoder"), &shared, &decoder_cfg)?; Ok(Self { encoder, @@ -605,7 +636,7 @@ impl T5ForConditionalGeneration { }) } - pub fn forward(&self, input_ids: &Tensor, decoder_input_ids: &Tensor) -> Result<Tensor> { + pub fn forward(&mut self, input_ids: &Tensor, decoder_input_ids: &Tensor) -> Result<Tensor> { let encoder_output = self.encoder.forward(input_ids, None)?; let decoder_output = self .decoder |