summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorHugo Abonizio <hugo_abonizio@hotmail.com>2024-03-25 14:20:09 -0300
committerGitHub <noreply@github.com>2024-03-25 18:20:09 +0100
commit60676780a9436fd0de43b1e8ff99445ab863c066 (patch)
tree206625a6f144e5c81ae48468aa915a0d27c11eb6
parentd3a8d291d5f2ff5addb9ff97cf881307afbd7b6a (diff)
downloadcandle-60676780a9436fd0de43b1e8ff99445ab863c066.tar.gz
candle-60676780a9436fd0de43b1e8ff99445ab863c066.tar.bz2
candle-60676780a9436fd0de43b1e8ff99445ab863c066.zip
Fix detail in new RoPE implementation (#1935)
-rw-r--r--candle-nn/src/rotary_emb.rs2
1 files changed, 1 insertions, 1 deletions
diff --git a/candle-nn/src/rotary_emb.rs b/candle-nn/src/rotary_emb.rs
index 9c5543fb..c2b41482 100644
--- a/candle-nn/src/rotary_emb.rs
+++ b/candle-nn/src/rotary_emb.rs
@@ -455,7 +455,7 @@ impl candle::CustomOp3 for RotaryEmb {
pub fn rope(xs: &Tensor, cos: &Tensor, sin: &Tensor) -> Result<Tensor> {
let (_b_sz, _n_head, seq_len, n_embd) = xs.dims4()?;
let (cos_seq_len, cos_n_embd) = cos.dims2()?;
- let (sin_seq_len, sin_n_embd) = cos.dims2()?;
+ let (sin_seq_len, sin_n_embd) = sin.dims2()?;
if cos_n_embd * 2 != n_embd
|| sin_n_embd * 2 != n_embd
|| seq_len > cos_seq_len