author    Laurent Mazare <laurent.mazare@gmail.com>    2024-04-18 22:34:29 +0200
committer GitHub <noreply@github.com>                  2024-04-18 22:34:29 +0200
commit    2b93dffe64d26829224f0f31e81f6c50c0e1e733 (patch)
tree      d1ad08ba2f76775b6b69f723d57449cd67ae91c4 /candle-transformers
parent    e6ee7ba4d46de6e5e1e003319da4a49a3a7a0813 (diff)
Use faster rotary embeddings for llama like models. (#2087)
Diffstat (limited to 'candle-transformers')
-rw-r--r--  candle-transformers/src/models/llama.rs | 17 ++++++-----------
 1 file changed, 6 insertions(+), 11 deletions(-)
diff --git a/candle-transformers/src/models/llama.rs b/candle-transformers/src/models/llama.rs
index 97a40d37..945c0e17 100644
--- a/candle-transformers/src/models/llama.rs
+++ b/candle-transformers/src/models/llama.rs
@@ -116,7 +116,6 @@ impl Cache {
.matmul(&theta.reshape((1, theta.elem_count()))?)?;
// This is different from the paper, see:
// https://github.com/huggingface/transformers/blob/6112b1c6442aaf7affd2b0676a1cd4eee30c45cf/src/transformers/models/llama/modeling_llama.py#L112
- let idx_theta = Tensor::cat(&[&idx_theta, &idx_theta], D::Minus1)?;
let cos = idx_theta.cos()?.to_dtype(dtype)?;
let sin = idx_theta.sin()?.to_dtype(dtype)?;
Ok(Self {
@@ -176,16 +175,10 @@ fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor>
impl CausalSelfAttention {
fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize, cache: &Cache) -> Result<Tensor> {
let _enter = self.span_rot.enter();
- let (b_sz, _, seq_len, hidden_size) = x.dims4()?;
+ let (_b_sz, _, seq_len, _hidden_size) = x.dims4()?;
let cos = cache.cos.narrow(0, index_pos, seq_len)?;
let sin = cache.sin.narrow(0, index_pos, seq_len)?;
- let cos = cos.broadcast_as((b_sz, 1, seq_len, hidden_size))?;
- let sin = sin.broadcast_as((b_sz, 1, seq_len, hidden_size))?;
- let x1 = x.narrow(D::Minus1, 0, hidden_size / 2)?;
- let x2 = x.narrow(D::Minus1, hidden_size / 2, hidden_size / 2)?;
- let rotate_x = Tensor::cat(&[&x2.neg()?, &x1], D::Minus1)?;
- let rope = (x.broadcast_mul(&cos)? + rotate_x.broadcast_mul(&sin)?)?;
- Ok(rope)
+ candle_nn::rotary_emb::rope(x, &cos, &sin)
}
fn forward(
@@ -203,10 +196,12 @@ impl CausalSelfAttention {
let q = q
.reshape((b_sz, seq_len, self.num_attention_heads, self.head_dim))?
- .transpose(1, 2)?;
+ .transpose(1, 2)?
+ .contiguous()?;
let k = k
.reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))?
- .transpose(1, 2)?;
+ .transpose(1, 2)?
+ .contiguous()?;
let mut v = v
.reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))?
.transpose(1, 2)?;
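For context on what the new call expects, here is a rough standalone sketch (not part of this commit; shapes and random tensors are made up for illustration). The diff suggests that candle_nn::rotary_emb::rope operates on a contiguous (batch, heads, seq_len, head_dim) tensor and takes cos/sin tables covering half the head dimension, which would explain both the removed Tensor::cat on idx_theta in the cache and the .contiguous() calls added after the transpose on q and k.

use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Hypothetical sizes: batch 1, 8 heads, 16 positions, head_dim 64.
    let (b, h, t, d) = (1usize, 8usize, 16usize, 64usize);
    // The fused kernel wants q/k contiguous in (batch, heads, seq_len, head_dim)
    // layout, hence the added .contiguous() after the transpose in the diff above.
    let q = Tensor::randn(0f32, 1f32, (b, h, t, d), &dev)?;
    // cos/sin only cover half the head dimension: shape (seq_len, head_dim / 2),
    // matching the removal of the Tensor::cat on idx_theta in the cache.
    let cos = Tensor::randn(0f32, 1f32, (t, d / 2), &dev)?;
    let sin = Tensor::randn(0f32, 1f32, (t, d / 2), &dev)?;
    let q_rot = candle_nn::rotary_emb::rope(&q, &cos, &sin)?;
    println!("{:?}", q_rot.dims()); // expected: [1, 8, 16, 64]
    Ok(())
}

In the model itself, cos and sin would come from the cache narrowed to the current positions, as apply_rotary_emb does above; since the output shape is unchanged, the rest of the attention code needs no further adjustment.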