summaryrefslogtreecommitdiff
path: root/candle-transformers/src/models/stable_lm.rs
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-10-06 19:20:35 +0100
committerGitHub <noreply@github.com>2023-10-06 19:20:35 +0100
commitd5f7267087bc253a2fe93c95ae78a164053646c1 (patch)
tree05e507c7130b9689675e69f17df5949b44367f53 /candle-transformers/src/models/stable_lm.rs
parent904bbdae65d69aac0c54c29eef744ca5e69c6733 (diff)
downloadcandle-d5f7267087bc253a2fe93c95ae78a164053646c1.tar.gz
candle-d5f7267087bc253a2fe93c95ae78a164053646c1.tar.bz2
candle-d5f7267087bc253a2fe93c95ae78a164053646c1.zip
Add the stable-lm example. (#1046)
* Add the stable-lm example. * Get stable-lm to generate some proper text.
Diffstat (limited to 'candle-transformers/src/models/stable_lm.rs')
-rw-r--r--candle-transformers/src/models/stable_lm.rs17
1 files changed, 13 insertions, 4 deletions
diff --git a/candle-transformers/src/models/stable_lm.rs b/candle-transformers/src/models/stable_lm.rs
index 772c5ec9..e86f8877 100644
--- a/candle-transformers/src/models/stable_lm.rs
+++ b/candle-transformers/src/models/stable_lm.rs
@@ -148,6 +148,7 @@ struct Attention {
rotary_emb: Arc<RotaryEmbedding>,
kv_cache: Option<(Tensor, Tensor)>,
use_cache: bool,
+ rotary_ndims: usize,
}
impl Attention {
@@ -173,6 +174,7 @@ impl Attention {
rotary_emb,
kv_cache: None,
use_cache: cfg.use_cache,
+ rotary_ndims: cfg.rotary_ndims(),
})
}
@@ -210,9 +212,16 @@ impl Attention {
.reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
.transpose(1, 2)?;
- let (query_states, key_states) =
+ let (rot_ndims, pass_ndims) = (self.rotary_ndims, self.head_dim - self.rotary_ndims);
+ let query_rot = query_states.narrow(D::Minus1, 0, rot_ndims)?;
+ let query_pass = query_states.narrow(D::Minus1, rot_ndims, pass_ndims)?;
+ let key_rot = key_states.narrow(D::Minus1, 0, rot_ndims)?;
+ let key_pass = key_states.narrow(D::Minus1, rot_ndims, pass_ndims)?;
+ let (query_rot, key_rot) =
self.rotary_emb
- .apply_rotary_emb_qkv(&query_states, &key_states, seqlen_offset)?;
+ .apply_rotary_emb_qkv(&query_rot, &key_rot, seqlen_offset)?;
+ let query_states = Tensor::cat(&[query_rot, query_pass], D::Minus1)?.contiguous()?;
+ let key_states = Tensor::cat(&[key_rot, key_pass], D::Minus1)?.contiguous()?;
let (key_states, value_states) = match &self.kv_cache {
None => (key_states, value_states),
@@ -226,8 +235,8 @@ impl Attention {
self.kv_cache = Some((key_states.clone(), value_states.clone()));
}
- let key_states = self.repeat_kv(key_states)?;
- let value_states = self.repeat_kv(value_states)?;
+ let key_states = self.repeat_kv(key_states)?.contiguous()?;
+ let value_states = self.repeat_kv(value_states)?.contiguous()?;
let attn_output = {
let scale = 1f64 / f64::sqrt(self.head_dim as f64);