summaryrefslogtreecommitdiff
path: root/candle-transformers/src/models/quantized_mpt.rs
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2024-04-12 09:15:10 +0200
committerGitHub <noreply@github.com>2024-04-12 09:15:10 +0200
commit3ad4770eb61be34e6d2a7914a935b007d8dee49f (patch)
tree17bf212fa44c7d4370b5877ddf0d92c136e110b0 /candle-transformers/src/models/quantized_mpt.rs
parenta0460cd2b13a396ff8545dc1bbffa741f2ec3d79 (diff)
downloadcandle-3ad4770eb61be34e6d2a7914a935b007d8dee49f.tar.gz
candle-3ad4770eb61be34e6d2a7914a935b007d8dee49f.tar.bz2
candle-3ad4770eb61be34e6d2a7914a935b007d8dee49f.zip
Use cat for faster MQA computation. (#2043)
* Use cat for faster MQA computation. * Move the function to utils + use it in mistral. * Use the shared repeat-kv in a few more models. * Fix.
Diffstat (limited to 'candle-transformers/src/models/quantized_mpt.rs')
-rw-r--r--candle-transformers/src/models/quantized_mpt.rs4
1 files changed, 2 insertions, 2 deletions
diff --git a/candle-transformers/src/models/quantized_mpt.rs b/candle-transformers/src/models/quantized_mpt.rs
index 70a9e125..056fcac2 100644
--- a/candle-transformers/src/models/quantized_mpt.rs
+++ b/candle-transformers/src/models/quantized_mpt.rs
@@ -71,8 +71,8 @@ impl GroupedQueryAttention {
};
self.kv_cache = Some((key.clone(), value.clone()));
let query = query.contiguous()?;
- let key = super::mpt::repeat_kv(key, self.n_heads / self.kv_n_heads)?.contiguous()?;
- let value = super::mpt::repeat_kv(value, self.n_heads / self.kv_n_heads)?.contiguous()?;
+ let key = crate::utils::repeat_kv(key, self.n_heads / self.kv_n_heads)?.contiguous()?;
+ let value = crate::utils::repeat_kv(value, self.n_heads / self.kv_n_heads)?.contiguous()?;
let attn_weights = (query.matmul(&key)? * self.softmax_scale)?;
let attn_bias = {
let s_q = query.dim(D::Minus2)?;