summaryrefslogtreecommitdiff
path: root/candle-transformers
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-10-31 09:47:44 +0100
committerGitHub <noreply@github.com>2023-10-31 08:47:44 +0000
commitc12ad45562778ffff0cda6c623b4838a2ed1c57c (patch)
tree57d249d8cd21386cfb2764da67324c430e7a97b4 /candle-transformers
parent7d0202710bba5b95d5161bd81528e90f0a406acc (diff)
downloadcandle-c12ad45562778ffff0cda6c623b4838a2ed1c57c.tar.gz
candle-c12ad45562778ffff0cda6c623b4838a2ed1c57c.tar.bz2
candle-c12ad45562778ffff0cda6c623b4838a2ed1c57c.zip
Add a KV cache to marian decoding. (#1226)
Diffstat (limited to 'candle-transformers')
-rw-r--r--candle-transformers/src/models/marian.rs54
1 files changed, 40 insertions, 14 deletions
diff --git a/candle-transformers/src/models/marian.rs b/candle-transformers/src/models/marian.rs
index 5305d4d8..ebab3dbc 100644
--- a/candle-transformers/src/models/marian.rs
+++ b/candle-transformers/src/models/marian.rs
@@ -126,6 +126,8 @@ struct Attention {
scaling: f64,
num_heads: usize,
head_dim: usize,
+ kv_cache: Option<(Tensor, Tensor)>,
+ is_decoder: bool,
}
impl Attention {
@@ -150,6 +152,8 @@ impl Attention {
scaling,
num_heads,
head_dim,
+ kv_cache: None,
+ is_decoder,
})
}
@@ -161,7 +165,7 @@ impl Attention {
}
fn forward(
- &self,
+ &mut self,
xs: &Tensor,
kv_states: Option<&Tensor>,
attn_mask: Option<&Tensor>,
@@ -173,7 +177,20 @@ impl Attention {
None => {
let key_states = self._shape(&xs.apply(&self.k_proj)?, b_sz)?;
let value_states = self._shape(&xs.apply(&self.v_proj)?, b_sz)?;
- (key_states, value_states)
+ if self.is_decoder {
+ let kv_states = match &self.kv_cache {
+ None => (key_states, value_states),
+ Some((p_key_states, p_value_states)) => {
+ let key_states = Tensor::cat(&[p_key_states, &key_states], 2)?;
+ let value_states = Tensor::cat(&[p_value_states, &value_states], 2)?;
+ (key_states, value_states)
+ }
+ };
+ self.kv_cache = Some(kv_states.clone());
+ kv_states
+ } else {
+ (key_states, value_states)
+ }
}
Some(kv_states) => {
let key_states = self._shape(&kv_states.apply(&self.k_proj)?, b_sz)?;
@@ -198,6 +215,10 @@ impl Attention {
.reshape((b_sz, tgt_len, self.head_dim * self.num_heads))?
.apply(&self.out_proj)
}
+
+ fn reset_kv_cache(&mut self) {
+ self.kv_cache = None
+ }
}
#[derive(Debug, Clone)]
@@ -227,7 +248,7 @@ impl EncoderLayer {
})
}
- fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ fn forward(&mut self, xs: &Tensor) -> Result<Tensor> {
let residual = xs;
let xs = (self.self_attn.forward(xs, None, None)? + residual)?
.apply(&self.self_attn_layer_norm)?;
@@ -275,7 +296,7 @@ impl DecoderLayer {
}
fn forward(
- &self,
+ &mut self,
xs: &Tensor,
encoder_xs: Option<&Tensor>,
attn_mask: &Tensor,
@@ -331,7 +352,7 @@ impl Encoder {
})
}
- pub fn forward(&self, xs: &Tensor, past_kv_len: usize) -> Result<Tensor> {
+ pub fn forward(&mut self, xs: &Tensor, past_kv_len: usize) -> Result<Tensor> {
let xs = xs.apply(&self.embed_tokens)?;
let xs = match self.embed_scale {
None => xs,
@@ -342,7 +363,7 @@ impl Encoder {
.forward(&xs, past_kv_len)?
.unsqueeze(0)?;
let mut xs = xs.broadcast_add(&embed_pos)?;
- for layer in self.layers.iter() {
+ for layer in self.layers.iter_mut() {
xs = layer.forward(&xs)?
}
Ok(xs)
@@ -380,7 +401,7 @@ impl Decoder {
}
pub fn forward(
- &self,
+ &mut self,
xs: &Tensor,
encoder_xs: Option<&Tensor>,
past_kv_len: usize,
@@ -396,7 +417,7 @@ impl Decoder {
.forward(&xs, past_kv_len)?
.unsqueeze(0)?;
let mut xs = xs.broadcast_add(&embed_pos)?;
- for layer in self.layers.iter() {
+ for layer in self.layers.iter_mut() {
xs = layer.forward(&xs, encoder_xs, attn_mask)?;
}
Ok(xs)
@@ -443,15 +464,20 @@ impl MTModel {
})
}
- pub fn encoder(&self) -> &Encoder {
- &self.model.encoder
+ pub fn encoder(&mut self) -> &mut Encoder {
+ &mut self.model.encoder
}
- pub fn decoder(&self) -> &Decoder {
- &self.model.decoder
+ pub fn decoder(&mut self) -> &mut Decoder {
+ &mut self.model.decoder
}
- pub fn decode(&self, xs: &Tensor, encoder_xs: &Tensor) -> Result<Tensor> {
+ pub fn decode(
+ &mut self,
+ xs: &Tensor,
+ encoder_xs: &Tensor,
+ past_kv_len: usize,
+ ) -> Result<Tensor> {
let seq_len = xs.dim(1)?;
let mask: Vec<_> = (0..seq_len)
.flat_map(|i| (0..seq_len).map(move |j| if j > i { f32::NEG_INFINITY } else { 0f32 }))
@@ -459,7 +485,7 @@ impl MTModel {
let mask = Tensor::from_vec(mask, (seq_len, seq_len), xs.device())?;
self.model
.decoder
- .forward(xs, Some(encoder_xs), 0, &mask)?
+ .forward(xs, Some(encoder_xs), past_kv_len, &mask)?
.apply(&self.lm_head)?
.broadcast_add(&self.final_logits_bias)
}