diff options
Diffstat (limited to 'candle-transformers/src/models/quantized_mistral.rs')
-rw-r--r-- | candle-transformers/src/models/quantized_mistral.rs | 20 |
1 files changed, 6 insertions, 14 deletions
diff --git a/candle-transformers/src/models/quantized_mistral.rs b/candle-transformers/src/models/quantized_mistral.rs index 2c5b7f74..e37785de 100644 --- a/candle-transformers/src/models/quantized_mistral.rs +++ b/candle-transformers/src/models/quantized_mistral.rs @@ -12,13 +12,6 @@ struct RotaryEmbedding { cos: Tensor, } -fn rotate_half(xs: &Tensor) -> Result<Tensor> { - let last_dim = xs.dim(D::Minus1)?; - let xs1 = xs.narrow(D::Minus1, 0, last_dim / 2)?; - let xs2 = xs.narrow(D::Minus1, last_dim / 2, last_dim - last_dim / 2)?; - Tensor::cat(&[&xs2.neg()?, &xs1], D::Minus1) -} - impl RotaryEmbedding { fn new(cfg: &Config, dev: &Device) -> Result<Self> { let rope_theta = cfg.rope_theta as f32; @@ -34,7 +27,6 @@ impl RotaryEmbedding { .to_dtype(DType::F32)? .reshape((max_seq_len, 1))?; let freqs = t.matmul(&inv_freq)?; - let freqs = Tensor::cat(&[&freqs, &freqs], D::Minus1)?; Ok(Self { sin: freqs.sin()?, cos: freqs.cos()?, @@ -50,10 +42,8 @@ impl RotaryEmbedding { let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?; let cos = self.cos.narrow(0, seqlen_offset, seq_len)?; let sin = self.sin.narrow(0, seqlen_offset, seq_len)?; - let cos = cos.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim) - let sin = sin.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim) - let q_embed = (q.broadcast_mul(&cos)? + rotate_half(q)?.broadcast_mul(&sin))?; - let k_embed = (k.broadcast_mul(&cos)? + rotate_half(k)?.broadcast_mul(&sin))?; + let q_embed = candle_nn::rotary_emb::rope(q, &cos, &sin)?; + let k_embed = candle_nn::rotary_emb::rope(k, &cos, &sin)?; Ok((q_embed, k_embed)) } } @@ -158,10 +148,12 @@ impl Attention { let query_states = query_states .reshape((b_sz, q_len, self.num_heads, self.head_dim))? - .transpose(1, 2)?; + .transpose(1, 2)? + .contiguous()?; let key_states = key_states .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))? - .transpose(1, 2)?; + .transpose(1, 2)? + .contiguous()?; let value_states = value_states .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))? .transpose(1, 2)?; |