summary | refs | log | tree | commit | diff
path: root/candle-transformers/src/models/quantized_mistral.rs
diff options
context:
space:
mode:
author    Laurent Mazare <laurent.mazare@gmail.com>  2024-04-12 09:15:10 +0200
committer GitHub <noreply@github.com>  2024-04-12 09:15:10 +0200
commit 3ad4770eb61be34e6d2a7914a935b007d8dee49f (patch)
tree   17bf212fa44c7d4370b5877ddf0d92c136e110b0 /candle-transformers/src/models/quantized_mistral.rs
parent a0460cd2b13a396ff8545dc1bbffa741f2ec3d79 (diff)
download: candle-3ad4770eb61be34e6d2a7914a935b007d8dee49f.tar.gz
candle-3ad4770eb61be34e6d2a7914a935b007d8dee49f.tar.bz2
candle-3ad4770eb61be34e6d2a7914a935b007d8dee49f.zip
Use cat for faster MQA computation. (#2043)
* Use cat for faster MQA computation.
* Move the function to utils + use it in mistral.
* Use the shared repeat-kv in a few more models.
* Fix.
Diffstat (limited to 'candle-transformers/src/models/quantized_mistral.rs')
-rw-r--r--  candle-transformers/src/models/quantized_mistral.rs  16
1 file changed, 2 insertions(+), 14 deletions(-)
diff --git a/candle-transformers/src/models/quantized_mistral.rs b/candle-transformers/src/models/quantized_mistral.rs
index e37785de..0583810a 100644
--- a/candle-transformers/src/models/quantized_mistral.rs
+++ b/candle-transformers/src/models/quantized_mistral.rs
@@ -122,18 +122,6 @@ impl Attention {
})
}
- fn repeat_kv(&self, xs: Tensor) -> Result<Tensor> {
- let n_rep = self.num_kv_groups;
- if n_rep == 1 {
- Ok(xs)
- } else {
- let (b_sz, num_kv_heads, seq_len, head_dim) = xs.dims4()?;
- xs.unsqueeze(2)?
- .expand((b_sz, num_kv_heads, n_rep, seq_len, head_dim))?
- .reshape((b_sz, num_kv_heads * n_rep, seq_len, head_dim))
- }
- }
-
fn forward(
&mut self,
xs: &Tensor,
@@ -172,8 +160,8 @@ impl Attention {
};
self.kv_cache = Some((key_states.clone(), value_states.clone()));
- let key_states = self.repeat_kv(key_states)?;
- let value_states = self.repeat_kv(value_states)?;
+ let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?;
+ let value_states = crate::utils::repeat_kv(value_states, self.num_kv_groups)?;
let attn_output = {
let scale = 1f64 / f64::sqrt(self.head_dim as f64);