summary refs log tree commit diff
diff options
context:
space:
mode:
author Laurent Mazare <laurent.mazare@gmail.com> 2024-01-13 19:44:41 +0100
committer GitHub <noreply@github.com> 2024-01-13 19:44:41 +0100
commit 88618255cb3c20b511a2f0e6db35d84081ce3c4a (patch)
tree ab9786cfd821101676fc53793e3eb440f131e600
parent 539ead927a12a485637f7f04f8212cfdabe00fa4 (diff)
download candle-88618255cb3c20b511a2f0e6db35d84081ce3c4a.tar.gz
candle-88618255cb3c20b511a2f0e6db35d84081ce3c4a.tar.bz2
candle-88618255cb3c20b511a2f0e6db35d84081ce3c4a.zip
Fix the rotary embeddings for the new phi implementation. (#1582)
* Fix the rotary embeddings for the new phi implementation. * Match the activation. * KV cache fix. * Use the config activation function.
-rw-r--r--candle-transformers/src/models/phi.rs34
1 files changed, 16 insertions, 18 deletions
diff --git a/candle-transformers/src/models/phi.rs b/candle-transformers/src/models/phi.rs
index a635f3ce..8bf357e7 100644
--- a/candle-transformers/src/models/phi.rs
+++ b/candle-transformers/src/models/phi.rs
@@ -38,6 +38,7 @@ impl Config {
#[derive(Debug, Clone)]
struct RotaryEmbedding {
+ dim: usize,
sin: Tensor,
cos: Tensor,
}
@@ -55,29 +56,24 @@ impl RotaryEmbedding {
.to_dtype(DType::F32)?
.reshape((cfg.max_position_embeddings, 1))?;
let freqs = t.matmul(&inv_freq)?;
+ let emb = Tensor::cat(&[&freqs, &freqs], D::Minus1)?;
Ok(Self {
- sin: freqs.sin()?,
- cos: freqs.cos()?,
+ dim,
+ sin: emb.sin()?,
+ cos: emb.cos()?,
})
}
fn apply_rotary_emb(&self, xs: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
- let (_b_size, seqlen, _, _headdim) = xs.dims4()?;
- let (_rotary_seqlen, rotary_dim) = self.cos.dims2()?;
- let rotary_dim = rotary_dim * 2;
- let xs_rot = xs.i((.., .., .., ..rotary_dim))?;
- let xs_pass = xs.i((.., .., .., rotary_dim..))?;
+ let (_b_size, _num_heads, seq_len, _headdim) = xs.dims4()?;
+ let xs_rot = xs.i((.., .., .., ..self.dim))?;
+ let xs_pass = xs.i((.., .., .., self.dim..))?;
let xs12 = xs_rot.chunk(2, D::Minus1)?;
let (xs1, xs2) = (&xs12[0], &xs12[1]);
- let c = self.cos.narrow(0, seqlen_offset, seqlen)?.unsqueeze(1)?;
- let s = self.sin.narrow(0, seqlen_offset, seqlen)?.unsqueeze(1)?;
- let xs_rot = Tensor::cat(
- &[
- (xs1.broadcast_mul(&c)? - xs2.broadcast_mul(&s)?)?,
- (xs1.broadcast_mul(&s)? + xs2.broadcast_mul(&c)?)?,
- ],
- D::Minus1,
- )?;
+ let c = self.cos.narrow(0, seqlen_offset, seq_len)?;
+ let s = self.sin.narrow(0, seqlen_offset, seq_len)?;
+ let rotate_half = Tensor::cat(&[&xs2.neg()?, &xs1], D::Minus1)?;
+ let xs_rot = (xs_rot.broadcast_mul(&c)? + rotate_half.broadcast_mul(&s)?)?;
Tensor::cat(&[&xs_rot, &xs_pass], D::Minus1)
}
}
@@ -97,6 +93,8 @@ impl MLP {
Ok(Self {
fc1,
fc2,
+ // This does not match the mixformers implementation where Gelu is used rather than
+ // GeluNew.
act: cfg.hidden_act,
})
}
@@ -216,7 +214,7 @@ impl Attention {
// Rotary embeddings.
let seqlen_offset = match &self.kv_cache {
None => 0,
- Some((prev_k, _)) => prev_k.dim(1)?,
+ Some((prev_k, _)) => prev_k.dim(2)?,
};
let query_states = self
.rotary_emb
@@ -351,7 +349,7 @@ impl Model {
Some(get_mask(seq_len, xs.device())?)
};
for layer in self.layers.iter_mut() {
- xs = layer.forward(&xs, mask.as_ref())?
+ xs = layer.forward(&xs, mask.as_ref())?;
}
xs.apply(&self.final_layernorm)?
.narrow(1, seq_len - 1, 1)?