summary | refs | log | tree | commit | diff
path: root/candle-transformers/src/models/mistral.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-transformers/src/models/mistral.rs')
-rw-r--r-- candle-transformers/src/models/mistral.rs | 20
1 files changed, 6 insertions, 14 deletions
diff --git a/candle-transformers/src/models/mistral.rs b/candle-transformers/src/models/mistral.rs
index 0e6200f5..d899c712 100644
--- a/candle-transformers/src/models/mistral.rs
+++ b/candle-transformers/src/models/mistral.rs
@@ -88,13 +88,6 @@ struct RotaryEmbedding {
cos: Tensor,
}
-fn rotate_half(xs: &Tensor) -> Result<Tensor> {
- let last_dim = xs.dim(D::Minus1)?;
- let xs1 = xs.narrow(D::Minus1, 0, last_dim / 2)?;
- let xs2 = xs.narrow(D::Minus1, last_dim / 2, last_dim - last_dim / 2)?;
- Tensor::cat(&[&xs2.neg()?, &xs1], D::Minus1)
-}
-
impl RotaryEmbedding {
fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
let rope_theta = cfg.rope_theta as f32;
@@ -110,7 +103,6 @@ impl RotaryEmbedding {
.to_dtype(dtype)?
.reshape((max_seq_len, 1))?;
let freqs = t.matmul(&inv_freq)?;
- let freqs = Tensor::cat(&[&freqs, &freqs], D::Minus1)?;
Ok(Self {
sin: freqs.sin()?,
cos: freqs.cos()?,
@@ -126,10 +118,8 @@ impl RotaryEmbedding {
let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
- let cos = cos.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim)
- let sin = sin.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim)
- let q_embed = (q.broadcast_mul(&cos)? + rotate_half(q)?.broadcast_mul(&sin))?;
- let k_embed = (k.broadcast_mul(&cos)? + rotate_half(k)?.broadcast_mul(&sin))?;
+ let q_embed = candle_nn::rotary_emb::rope(q, &cos, &sin)?;
+ let k_embed = candle_nn::rotary_emb::rope(k, &cos, &sin)?;
Ok((q_embed, k_embed))
}
}
@@ -252,10 +242,12 @@ impl Attention {
let query_states = query_states
.reshape((b_sz, q_len, self.num_heads, self.head_dim))?
- .transpose(1, 2)?;
+ .transpose(1, 2)?
+ .contiguous()?;
let key_states = key_states
.reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
- .transpose(1, 2)?;
+ .transpose(1, 2)?
+ .contiguous()?;
let value_states = value_states
.reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
.transpose(1, 2)?;