diff options
Diffstat (limited to 'candle-transformers/src/utils.rs')
-rw-r--r-- | candle-transformers/src/utils.rs | 14 |
1 files changed, 14 insertions, 0 deletions
diff --git a/candle-transformers/src/utils.rs b/candle-transformers/src/utils.rs index d29995ed..17e83694 100644 --- a/candle-transformers/src/utils.rs +++ b/candle-transformers/src/utils.rs @@ -20,3 +20,17 @@ pub fn apply_repeat_penalty(logits: &Tensor, penalty: f32, context: &[u32]) -> R let logits_len = logits.len(); Tensor::from_vec(logits, logits_len, device) } + +/// Repeats a key or value tensor for grouped query attention +/// The input tensor should have a shape `(batch, num_kv_heads, seq_len, head_dim)`, +pub fn repeat_kv(xs: Tensor, n_rep: usize) -> Result<Tensor> { + if n_rep == 1 { + Ok(xs) + } else { + let (b_sz, n_kv_head, seq_len, head_dim) = xs.dims4()?; + // Using cat is faster than a broadcast as it avoids going through a potentially + // strided copy. + // https://github.com/huggingface/candle/pull/2043 + Tensor::cat(&vec![&xs; n_rep], 2)?.reshape((b_sz, n_kv_head * n_rep, seq_len, head_dim)) + } +} |