diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-04-19 21:49:55 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-04-19 21:49:55 +0200 |
commit | b45c710dbf61445751ae56052131ccd40a25b6b8 (patch) | |
tree | 57e9f8e417f663536d5d0a3691dad70f4876f0f0 /candle-transformers | |
parent | 9c532aef4751ad33cb74bb81b506cdb3011b5bef (diff) | |
download | candle-b45c710dbf61445751ae56052131ccd40a25b6b8.tar.gz candle-b45c710dbf61445751ae56052131ccd40a25b6b8.tar.bz2 candle-b45c710dbf61445751ae56052131ccd40a25b6b8.zip |
Fix for gemma MQA. (#2091)
Diffstat (limited to 'candle-transformers')
-rw-r--r-- | candle-transformers/src/models/gemma.rs | 5 |
1 files changed, 3 insertions, 2 deletions
diff --git a/candle-transformers/src/models/gemma.rs b/candle-transformers/src/models/gemma.rs index 58b5f1e1..3bde88b4 100644 --- a/candle-transformers/src/models/gemma.rs +++ b/candle-transformers/src/models/gemma.rs @@ -227,8 +227,9 @@ impl Attention { }; self.kv_cache = Some((key_states.clone(), value_states.clone())); - let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?; - let value_states = crate::utils::repeat_kv(value_states, self.num_kv_groups)?; + let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?.contiguous()?; + let value_states = + crate::utils::repeat_kv(value_states, self.num_kv_groups)?.contiguous()?; let attn_output = { let scale = 1f64 / f64::sqrt(self.head_dim as f64); |