diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-10-06 19:20:35 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-10-06 19:20:35 +0100 |
commit | d5f7267087bc253a2fe93c95ae78a164053646c1 (patch) | |
tree | 05e507c7130b9689675e69f17df5949b44367f53 /candle-transformers/src/models/stable_lm.rs | |
parent | 904bbdae65d69aac0c54c29eef744ca5e69c6733 (diff) | |
download | candle-d5f7267087bc253a2fe93c95ae78a164053646c1.tar.gz candle-d5f7267087bc253a2fe93c95ae78a164053646c1.tar.bz2 candle-d5f7267087bc253a2fe93c95ae78a164053646c1.zip |
Add the stable-lm example. (#1046)
* Add the stable-lm example.
* Get stable-lm to generate some proper text.
Diffstat (limited to 'candle-transformers/src/models/stable_lm.rs')
-rw-r--r-- | candle-transformers/src/models/stable_lm.rs | 17 |
1 file changed, 13 insertions(+), 4 deletions(-)
diff --git a/candle-transformers/src/models/stable_lm.rs b/candle-transformers/src/models/stable_lm.rs index 772c5ec9..e86f8877 100644 --- a/candle-transformers/src/models/stable_lm.rs +++ b/candle-transformers/src/models/stable_lm.rs @@ -148,6 +148,7 @@ struct Attention { rotary_emb: Arc<RotaryEmbedding>, kv_cache: Option<(Tensor, Tensor)>, use_cache: bool, + rotary_ndims: usize, } impl Attention { @@ -173,6 +174,7 @@ impl Attention { rotary_emb, kv_cache: None, use_cache: cfg.use_cache, + rotary_ndims: cfg.rotary_ndims(), }) } @@ -210,9 +212,16 @@ impl Attention { .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))? .transpose(1, 2)?; - let (query_states, key_states) = + let (rot_ndims, pass_ndims) = (self.rotary_ndims, self.head_dim - self.rotary_ndims); + let query_rot = query_states.narrow(D::Minus1, 0, rot_ndims)?; + let query_pass = query_states.narrow(D::Minus1, rot_ndims, pass_ndims)?; + let key_rot = key_states.narrow(D::Minus1, 0, rot_ndims)?; + let key_pass = key_states.narrow(D::Minus1, rot_ndims, pass_ndims)?; + let (query_rot, key_rot) = self.rotary_emb - .apply_rotary_emb_qkv(&query_states, &key_states, seqlen_offset)?; + .apply_rotary_emb_qkv(&query_rot, &key_rot, seqlen_offset)?; + let query_states = Tensor::cat(&[query_rot, query_pass], D::Minus1)?.contiguous()?; + let key_states = Tensor::cat(&[key_rot, key_pass], D::Minus1)?.contiguous()?; let (key_states, value_states) = match &self.kv_cache { None => (key_states, value_states), @@ -226,8 +235,8 @@ impl Attention { self.kv_cache = Some((key_states.clone(), value_states.clone())); } - let key_states = self.repeat_kv(key_states)?; - let value_states = self.repeat_kv(value_states)?; + let key_states = self.repeat_kv(key_states)?.contiguous()?; + let value_states = self.repeat_kv(value_states)?.contiguous()?; let attn_output = { let scale = 1f64 / f64::sqrt(self.head_dim as f64); |