diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-04-18 22:34:29 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-04-18 22:34:29 +0200 |
commit | 2b93dffe64d26829224f0f31e81f6c50c0e1e733 (patch) | |
tree | d1ad08ba2f76775b6b69f723d57449cd67ae91c4 /candle-transformers | |
parent | e6ee7ba4d46de6e5e1e003319da4a49a3a7a0813 (diff) | |
download | candle-2b93dffe64d26829224f0f31e81f6c50c0e1e733.tar.gz candle-2b93dffe64d26829224f0f31e81f6c50c0e1e733.tar.bz2 candle-2b93dffe64d26829224f0f31e81f6c50c0e1e733.zip |
Use faster rotary embeddings for llama like models. (#2087)
Diffstat (limited to 'candle-transformers')
-rw-r--r-- | candle-transformers/src/models/llama.rs | 17 |
1 files changed, 6 insertions, 11 deletions
diff --git a/candle-transformers/src/models/llama.rs b/candle-transformers/src/models/llama.rs index 97a40d37..945c0e17 100644 --- a/candle-transformers/src/models/llama.rs +++ b/candle-transformers/src/models/llama.rs @@ -116,7 +116,6 @@ impl Cache { .matmul(&theta.reshape((1, theta.elem_count()))?)?; // This is different from the paper, see: // https://github.com/huggingface/transformers/blob/6112b1c6442aaf7affd2b0676a1cd4eee30c45cf/src/transformers/models/llama/modeling_llama.py#L112 - let idx_theta = Tensor::cat(&[&idx_theta, &idx_theta], D::Minus1)?; let cos = idx_theta.cos()?.to_dtype(dtype)?; let sin = idx_theta.sin()?.to_dtype(dtype)?; Ok(Self { @@ -176,16 +175,10 @@ fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Ten impl CausalSelfAttention { fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize, cache: &Cache) -> Result<Tensor> { let _enter = self.span_rot.enter(); - let (b_sz, _, seq_len, hidden_size) = x.dims4()?; + let (_b_sz, _, seq_len, _hidden_size) = x.dims4()?; let cos = cache.cos.narrow(0, index_pos, seq_len)?; let sin = cache.sin.narrow(0, index_pos, seq_len)?; - let cos = cos.broadcast_as((b_sz, 1, seq_len, hidden_size))?; - let sin = sin.broadcast_as((b_sz, 1, seq_len, hidden_size))?; - let x1 = x.narrow(D::Minus1, 0, hidden_size / 2)?; - let x2 = x.narrow(D::Minus1, hidden_size / 2, hidden_size / 2)?; - let rotate_x = Tensor::cat(&[&x2.neg()?, &x1], D::Minus1)?; - let rope = (x.broadcast_mul(&cos)? + rotate_x.broadcast_mul(&sin)?)?; - Ok(rope) + candle_nn::rotary_emb::rope(x, &cos, &sin) } fn forward( @@ -203,10 +196,12 @@ impl CausalSelfAttention { let q = q .reshape((b_sz, seq_len, self.num_attention_heads, self.head_dim))? - .transpose(1, 2)?; + .transpose(1, 2)? + .contiguous()?; let k = k .reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))? - .transpose(1, 2)?; + .transpose(1, 2)? + .contiguous()?; let mut v = v .reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))? .transpose(1, 2)?; |