summary | refs | log | tree | commit | diff
path: root/candle-transformers/src/models/starcoder2.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-transformers/src/models/starcoder2.rs')
-rw-r--r-- candle-transformers/src/models/starcoder2.rs | 16
1 files changed, 2 insertions, 14 deletions
diff --git a/candle-transformers/src/models/starcoder2.rs b/candle-transformers/src/models/starcoder2.rs
index da3f6799..d108d062 100644
--- a/candle-transformers/src/models/starcoder2.rs
+++ b/candle-transformers/src/models/starcoder2.rs
@@ -139,18 +139,6 @@ impl Attention {
})
}
- fn repeat_kv(&self, xs: Tensor) -> Result<Tensor> {
- let n_rep = self.num_kv_groups;
- if n_rep == 1 {
- Ok(xs)
- } else {
- let (b_sz, num_kv_heads, seq_len, head_dim) = xs.dims4()?;
- xs.unsqueeze(2)?
- .expand((b_sz, num_kv_heads, n_rep, seq_len, head_dim))?
- .reshape((b_sz, num_kv_heads * n_rep, seq_len, head_dim))
- }
- }
-
fn forward(
&mut self,
xs: &Tensor,
@@ -187,8 +175,8 @@ impl Attention {
};
self.kv_cache = Some((key_states.clone(), value_states.clone()));
- let key_states = self.repeat_kv(key_states)?;
- let value_states = self.repeat_kv(value_states)?;
+ let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?;
+ let value_states = crate::utils::repeat_kv(value_states, self.num_kv_groups)?;
let scale = 1f64 / f64::sqrt(self.head_dim as f64);
let attn_weights = (query_states.matmul(&key_states.transpose(2, 3)?)? * scale)?;