author    | Laurent Mazare <laurent.mazare@gmail.com> | 2023-10-31 09:47:44 +0100
committer | GitHub <noreply@github.com>               | 2023-10-31 08:47:44 +0000
commit    | c12ad45562778ffff0cda6c623b4838a2ed1c57c (patch)
tree      | 57d249d8cd21386cfb2764da67324c430e7a97b4 /candle-transformers
parent    | 7d0202710bba5b95d5161bd81528e90f0a406acc (diff)
Add a KV cache to marian decoding. (#1226)
Diffstat (limited to 'candle-transformers')
-rw-r--r-- | candle-transformers/src/models/marian.rs | 54 |
1 file changed, 40 insertions(+), 14 deletions(-)
diff --git a/candle-transformers/src/models/marian.rs b/candle-transformers/src/models/marian.rs
index 5305d4d8..ebab3dbc 100644
--- a/candle-transformers/src/models/marian.rs
+++ b/candle-transformers/src/models/marian.rs
@@ -126,6 +126,8 @@ struct Attention {
     scaling: f64,
     num_heads: usize,
     head_dim: usize,
+    kv_cache: Option<(Tensor, Tensor)>,
+    is_decoder: bool,
 }
 
 impl Attention {
@@ -150,6 +152,8 @@ impl Attention {
             scaling,
             num_heads,
             head_dim,
+            kv_cache: None,
+            is_decoder,
         })
     }
 
@@ -161,7 +165,7 @@ impl Attention {
     }
 
     fn forward(
-        &self,
+        &mut self,
         xs: &Tensor,
         kv_states: Option<&Tensor>,
         attn_mask: Option<&Tensor>,
@@ -173,7 +177,20 @@ impl Attention {
             None => {
                 let key_states = self._shape(&xs.apply(&self.k_proj)?, b_sz)?;
                 let value_states = self._shape(&xs.apply(&self.v_proj)?, b_sz)?;
-                (key_states, value_states)
+                if self.is_decoder {
+                    let kv_states = match &self.kv_cache {
+                        None => (key_states, value_states),
+                        Some((p_key_states, p_value_states)) => {
+                            let key_states = Tensor::cat(&[p_key_states, &key_states], 2)?;
+                            let value_states = Tensor::cat(&[p_value_states, &value_states], 2)?;
+                            (key_states, value_states)
+                        }
+                    };
+                    self.kv_cache = Some(kv_states.clone());
+                    kv_states
+                } else {
+                    (key_states, value_states)
+                }
             }
             Some(kv_states) => {
                 let key_states = self._shape(&kv_states.apply(&self.k_proj)?, b_sz)?;
@@ -198,6 +215,10 @@ impl Attention {
             .reshape((b_sz, tgt_len, self.head_dim * self.num_heads))?
             .apply(&self.out_proj)
     }
+
+    fn reset_kv_cache(&mut self) {
+        self.kv_cache = None
+    }
 }
 
 #[derive(Debug, Clone)]
@@ -227,7 +248,7 @@ impl EncoderLayer {
         })
     }
 
-    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+    fn forward(&mut self, xs: &Tensor) -> Result<Tensor> {
         let residual = xs;
         let xs = (self.self_attn.forward(xs, None, None)? + residual)?
             .apply(&self.self_attn_layer_norm)?;
@@ -275,7 +296,7 @@ impl DecoderLayer {
     }
 
     fn forward(
-        &self,
+        &mut self,
         xs: &Tensor,
         encoder_xs: Option<&Tensor>,
         attn_mask: &Tensor,
@@ -331,7 +352,7 @@ impl Encoder {
         })
     }
 
-    pub fn forward(&self, xs: &Tensor, past_kv_len: usize) -> Result<Tensor> {
+    pub fn forward(&mut self, xs: &Tensor, past_kv_len: usize) -> Result<Tensor> {
         let xs = xs.apply(&self.embed_tokens)?;
         let xs = match self.embed_scale {
             None => xs,
@@ -342,7 +363,7 @@ impl Encoder {
             .forward(&xs, past_kv_len)?
             .unsqueeze(0)?;
         let mut xs = xs.broadcast_add(&embed_pos)?;
-        for layer in self.layers.iter() {
+        for layer in self.layers.iter_mut() {
             xs = layer.forward(&xs)?
         }
         Ok(xs)
@@ -380,7 +401,7 @@ impl Decoder {
     }
 
     pub fn forward(
-        &self,
+        &mut self,
         xs: &Tensor,
         encoder_xs: Option<&Tensor>,
         past_kv_len: usize,
@@ -396,7 +417,7 @@ impl Decoder {
             .forward(&xs, past_kv_len)?
             .unsqueeze(0)?;
         let mut xs = xs.broadcast_add(&embed_pos)?;
-        for layer in self.layers.iter() {
+        for layer in self.layers.iter_mut() {
             xs = layer.forward(&xs, encoder_xs, attn_mask)?;
         }
         Ok(xs)
@@ -443,15 +464,20 @@ impl MTModel {
         })
     }
 
-    pub fn encoder(&self) -> &Encoder {
-        &self.model.encoder
+    pub fn encoder(&mut self) -> &mut Encoder {
+        &mut self.model.encoder
     }
 
-    pub fn decoder(&self) -> &Decoder {
-        &self.model.decoder
+    pub fn decoder(&mut self) -> &mut Decoder {
+        &mut self.model.decoder
    }
 
-    pub fn decode(&self, xs: &Tensor, encoder_xs: &Tensor) -> Result<Tensor> {
+    pub fn decode(
+        &mut self,
+        xs: &Tensor,
+        encoder_xs: &Tensor,
+        past_kv_len: usize,
+    ) -> Result<Tensor> {
         let seq_len = xs.dim(1)?;
         let mask: Vec<_> = (0..seq_len)
             .flat_map(|i| (0..seq_len).map(move |j| if j > i { f32::NEG_INFINITY } else { 0f32 }))
@@ -459,7 +485,7 @@ impl MTModel {
         let mask = Tensor::from_vec(mask, (seq_len, seq_len), xs.device())?;
         self.model
             .decoder
-            .forward(xs, Some(encoder_xs), 0, &mask)?
+            .forward(xs, Some(encoder_xs), past_kv_len, &mask)?
             .apply(&self.lm_head)?
             .broadcast_add(&self.final_logits_bias)
     }
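
For orientation, here is a minimal sketch of a decoding loop driving the changed API; it is not part of this commit. After the first step only the newest token is passed to `decode`, and `past_kv_len` tells the decoder how many positions the attention layers already hold in their KV caches (they concatenate the cached keys/values with the new ones along the sequence dimension). The function name `translate`, the fixed seed, the CPU device, and the 128-token cap are illustrative assumptions; only `MTModel::encoder`, `MTModel::decode`, and `LogitsProcessor` come from the crate.

// Illustrative sketch only; names and limits marked above are assumptions.
use candle::{Device, Result, Tensor};
use candle_transformers::generation::LogitsProcessor;
use candle_transformers::models::marian::MTModel;

fn translate(model: &mut MTModel, src_ids: &[u32], bos: u32, eos: u32) -> Result<Vec<u32>> {
    let device = Device::Cpu;
    let mut sampler = LogitsProcessor::new(299792458, None, None);

    // Run the encoder once; its output is reused at every decoding step.
    let src = Tensor::new(src_ids, &device)?.unsqueeze(0)?;
    let encoder_xs = model.encoder().forward(&src, 0)?;

    let mut token_ids = vec![bos];
    for step in 0..128 {
        // After the first step only the newest token is fed in; earlier
        // positions are served from the decoder's KV cache.
        let context_len = if step == 0 { token_ids.len() } else { 1 };
        let start_pos = token_ids.len() - context_len;
        let xs = Tensor::new(&token_ids[start_pos..], &device)?.unsqueeze(0)?;
        let logits = model.decode(&xs, &encoder_xs, start_pos)?;
        // Sample from the logits of the last position only.
        let logits = logits.squeeze(0)?;
        let logits = logits.get(logits.dim(0)? - 1)?;
        let next = sampler.sample(&logits)?;
        token_ids.push(next);
        if next == eos {
            break;
        }
    }
    Ok(token_ids)
}

Note that this commit only adds `reset_kv_cache` on the private `Attention` struct and does not expose a cache reset on `MTModel`, so the sketch assumes a freshly loaded model per source sentence.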