summaryrefslogtreecommitdiff
path: root/candle-transformers
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2024-04-19 21:49:55 +0200
committerGitHub <noreply@github.com>2024-04-19 21:49:55 +0200
commitb45c710dbf61445751ae56052131ccd40a25b6b8 (patch)
tree57e9f8e417f663536d5d0a3691dad70f4876f0f0 /candle-transformers
parent9c532aef4751ad33cb74bb81b506cdb3011b5bef (diff)
downloadcandle-b45c710dbf61445751ae56052131ccd40a25b6b8.tar.gz
candle-b45c710dbf61445751ae56052131ccd40a25b6b8.tar.bz2
candle-b45c710dbf61445751ae56052131ccd40a25b6b8.zip
Fix for gemma MQA. (#2091)
Diffstat (limited to 'candle-transformers')
-rw-r--r--candle-transformers/src/models/gemma.rs5
1 files changed, 3 insertions, 2 deletions
diff --git a/candle-transformers/src/models/gemma.rs b/candle-transformers/src/models/gemma.rs
index 58b5f1e1..3bde88b4 100644
--- a/candle-transformers/src/models/gemma.rs
+++ b/candle-transformers/src/models/gemma.rs
@@ -227,8 +227,9 @@ impl Attention {
};
self.kv_cache = Some((key_states.clone(), value_states.clone()));
- let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?;
- let value_states = crate::utils::repeat_kv(value_states, self.num_kv_groups)?;
+ let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?.contiguous()?;
+ let value_states =
+ crate::utils::repeat_kv(value_states, self.num_kv_groups)?.contiguous()?;
let attn_output = {
let scale = 1f64 / f64::sqrt(self.head_dim as f64);