diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-04-12 09:15:10 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-04-12 09:15:10 +0200 |
commit | 3ad4770eb61be34e6d2a7914a935b007d8dee49f (patch) | |
tree | 17bf212fa44c7d4370b5877ddf0d92c136e110b0 /candle-transformers/src/models/starcoder2.rs | |
parent | a0460cd2b13a396ff8545dc1bbffa741f2ec3d79 (diff) | |
download | candle-3ad4770eb61be34e6d2a7914a935b007d8dee49f.tar.gz candle-3ad4770eb61be34e6d2a7914a935b007d8dee49f.tar.bz2 candle-3ad4770eb61be34e6d2a7914a935b007d8dee49f.zip |
Use cat for faster MQA computation. (#2043)
* Use cat for faster MQA computation.
* Move the function to utils + use it in mistral.
* Use the shared repeat-kv in a few more models.
* Fix.
Diffstat (limited to 'candle-transformers/src/models/starcoder2.rs')
-rw-r--r-- | candle-transformers/src/models/starcoder2.rs | 16 |
1 files changed, 2 insertions, 14 deletions
diff --git a/candle-transformers/src/models/starcoder2.rs b/candle-transformers/src/models/starcoder2.rs index da3f6799..d108d062 100644 --- a/candle-transformers/src/models/starcoder2.rs +++ b/candle-transformers/src/models/starcoder2.rs @@ -139,18 +139,6 @@ impl Attention { }) } - fn repeat_kv(&self, xs: Tensor) -> Result<Tensor> { - let n_rep = self.num_kv_groups; - if n_rep == 1 { - Ok(xs) - } else { - let (b_sz, num_kv_heads, seq_len, head_dim) = xs.dims4()?; - xs.unsqueeze(2)? - .expand((b_sz, num_kv_heads, n_rep, seq_len, head_dim))? - .reshape((b_sz, num_kv_heads * n_rep, seq_len, head_dim)) - } - } - fn forward( &mut self, xs: &Tensor, @@ -187,8 +175,8 @@ impl Attention { }; self.kv_cache = Some((key_states.clone(), value_states.clone())); - let key_states = self.repeat_kv(key_states)?; - let value_states = self.repeat_kv(value_states)?; + let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?; + let value_states = crate::utils::repeat_kv(value_states, self.num_kv_groups)?; let scale = 1f64 / f64::sqrt(self.head_dim as f64); let attn_weights = (query_states.matmul(&key_states.transpose(2, 3)?)? * scale)?; |