summaryrefslogtreecommitdiff
path: root/candle-transformers
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2024-05-24 16:48:21 +0200
committerGitHub <noreply@github.com>2024-05-24 16:48:21 +0200
commit3ceca9901a5ebc4ded3ac2cd793d0125f7a12562 (patch)
tree364793408840c261956f04fed2b0caf430655c41 /candle-transformers
parent1df2bddccfbb4ab511a8cc3a87476d1fa72416bc (diff)
downloadcandle-3ceca9901a5ebc4ded3ac2cd793d0125f7a12562.tar.gz
candle-3ceca9901a5ebc4ded3ac2cd793d0125f7a12562.tar.bz2
candle-3ceca9901a5ebc4ded3ac2cd793d0125f7a12562.zip
Enable the new layer-norm. (#2213)
* Enable the new layer-norm. * Shape fixes.
Diffstat (limited to 'candle-transformers')
-rw-r--r--candle-transformers/src/models/phi.rs12
1 files changed, 4 insertions, 8 deletions
diff --git a/candle-transformers/src/models/phi.rs b/candle-transformers/src/models/phi.rs
index 3f8d92b9..bffc14fa 100644
--- a/candle-transformers/src/models/phi.rs
+++ b/candle-transformers/src/models/phi.rs
@@ -56,24 +56,20 @@ impl RotaryEmbedding {
.to_dtype(DType::F32)?
.reshape((cfg.max_position_embeddings, 1))?;
let freqs = t.matmul(&inv_freq)?;
- let emb = Tensor::cat(&[&freqs, &freqs], D::Minus1)?;
Ok(Self {
dim,
- sin: emb.sin()?,
- cos: emb.cos()?,
+ sin: freqs.sin()?,
+ cos: freqs.cos()?,
})
}
fn apply_rotary_emb(&self, xs: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
let (_b_size, _num_heads, seq_len, _headdim) = xs.dims4()?;
- let xs_rot = xs.i((.., .., .., ..self.dim))?;
+ let xs_rot = xs.i((.., .., .., ..self.dim))?.contiguous()?;
let xs_pass = xs.i((.., .., .., self.dim..))?;
- let xs12 = xs_rot.chunk(2, D::Minus1)?;
- let (xs1, xs2) = (&xs12[0], &xs12[1]);
let c = self.cos.narrow(0, seqlen_offset, seq_len)?;
let s = self.sin.narrow(0, seqlen_offset, seq_len)?;
- let rotate_half = Tensor::cat(&[&xs2.neg()?, xs1], D::Minus1)?;
- let xs_rot = (xs_rot.broadcast_mul(&c)? + rotate_half.broadcast_mul(&s)?)?;
+ let xs_rot = candle_nn::rotary_emb::rope(&xs_rot, &c, &s)?;
Tensor::cat(&[&xs_rot, &xs_pass], D::Minus1)
}
}