summaryrefslogtreecommitdiff
path: root/candle-transformers/src/utils.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-transformers/src/utils.rs')
-rw-r--r--candle-transformers/src/utils.rs14
1 files changed, 14 insertions, 0 deletions
diff --git a/candle-transformers/src/utils.rs b/candle-transformers/src/utils.rs
index d29995ed..17e83694 100644
--- a/candle-transformers/src/utils.rs
+++ b/candle-transformers/src/utils.rs
@@ -20,3 +20,17 @@ pub fn apply_repeat_penalty(logits: &Tensor, penalty: f32, context: &[u32]) -> R
let logits_len = logits.len();
Tensor::from_vec(logits, logits_len, device)
}
+
+/// Repeats a key or value tensor for grouped query attention
+/// The input tensor should have a shape `(batch, num_kv_heads, seq_len, head_dim)`,
+pub fn repeat_kv(xs: Tensor, n_rep: usize) -> Result<Tensor> {
+ if n_rep == 1 {
+ Ok(xs)
+ } else {
+ let (b_sz, n_kv_head, seq_len, head_dim) = xs.dims4()?;
+ // Using cat is faster than a broadcast as it avoids going through a potentially
+ // strided copy.
+ // https://github.com/huggingface/candle/pull/2043
+ Tensor::cat(&vec![&xs; n_rep], 2)?.reshape((b_sz, n_kv_head * n_rep, seq_len, head_dim))
+ }
+}