diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-04-12 09:15:10 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-04-12 09:15:10 +0200 |
commit | 3ad4770eb61be34e6d2a7914a935b007d8dee49f (patch) | |
tree | 17bf212fa44c7d4370b5877ddf0d92c136e110b0 /candle-transformers/src/models/quantized_mpt.rs | |
parent | a0460cd2b13a396ff8545dc1bbffa741f2ec3d79 (diff) | |
download | candle-3ad4770eb61be34e6d2a7914a935b007d8dee49f.tar.gz candle-3ad4770eb61be34e6d2a7914a935b007d8dee49f.tar.bz2 candle-3ad4770eb61be34e6d2a7914a935b007d8dee49f.zip |
Use cat for faster MQA computation. (#2043)
* Use cat for faster MQA computation.
* Move the function to utils + use it in mistral.
* Use the shared repeat-kv in a few more models.
* Fix.
Diffstat (limited to 'candle-transformers/src/models/quantized_mpt.rs')
-rw-r--r-- | candle-transformers/src/models/quantized_mpt.rs | 4 |
1 files changed, 2 insertions, 2 deletions
diff --git a/candle-transformers/src/models/quantized_mpt.rs b/candle-transformers/src/models/quantized_mpt.rs index 70a9e125..056fcac2 100644 --- a/candle-transformers/src/models/quantized_mpt.rs +++ b/candle-transformers/src/models/quantized_mpt.rs @@ -71,8 +71,8 @@ impl GroupedQueryAttention { }; self.kv_cache = Some((key.clone(), value.clone())); let query = query.contiguous()?; - let key = super::mpt::repeat_kv(key, self.n_heads / self.kv_n_heads)?.contiguous()?; - let value = super::mpt::repeat_kv(value, self.n_heads / self.kv_n_heads)?.contiguous()?; + let key = crate::utils::repeat_kv(key, self.n_heads / self.kv_n_heads)?.contiguous()?; + let value = crate::utils::repeat_kv(value, self.n_heads / self.kv_n_heads)?.contiguous()?; let attn_weights = (query.matmul(&key)? * self.softmax_scale)?; let attn_bias = { let s_q = query.dim(D::Minus2)?; |