diff options
author | Hugo Abonizio <hugo_abonizio@hotmail.com> | 2024-03-25 14:20:09 -0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-03-25 18:20:09 +0100 |
commit | 60676780a9436fd0de43b1e8ff99445ab863c066 (patch) | |
tree | 206625a6f144e5c81ae48468aa915a0d27c11eb6 /candle-nn | |
parent | d3a8d291d5f2ff5addb9ff97cf881307afbd7b6a (diff) | |
download | candle-60676780a9436fd0de43b1e8ff99445ab863c066.tar.gz candle-60676780a9436fd0de43b1e8ff99445ab863c066.tar.bz2 candle-60676780a9436fd0de43b1e8ff99445ab863c066.zip |
Fix detail in new RoPE implementation (#1935)
Diffstat (limited to 'candle-nn')
-rw-r--r-- | candle-nn/src/rotary_emb.rs | 2 |
1 files changed, 1 insertions, 1 deletions
diff --git a/candle-nn/src/rotary_emb.rs b/candle-nn/src/rotary_emb.rs index 9c5543fb..c2b41482 100644 --- a/candle-nn/src/rotary_emb.rs +++ b/candle-nn/src/rotary_emb.rs @@ -455,7 +455,7 @@ impl candle::CustomOp3 for RotaryEmb { pub fn rope(xs: &Tensor, cos: &Tensor, sin: &Tensor) -> Result<Tensor> { let (_b_sz, _n_head, seq_len, n_embd) = xs.dims4()?; let (cos_seq_len, cos_n_embd) = cos.dims2()?; - let (sin_seq_len, sin_n_embd) = cos.dims2()?; + let (sin_seq_len, sin_n_embd) = sin.dims2()?; if cos_n_embd * 2 != n_embd || sin_n_embd * 2 != n_embd || seq_len > cos_seq_len |