From d5f7267087bc253a2fe93c95ae78a164053646c1 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Fri, 6 Oct 2023 19:20:35 +0100 Subject: Add the stable-lm example. (#1046) * Add the stable-lm example. * Get stable-lm to generate some proper text. --- candle-transformers/src/models/stable_lm.rs | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) (limited to 'candle-transformers/src/models/stable_lm.rs') diff --git a/candle-transformers/src/models/stable_lm.rs b/candle-transformers/src/models/stable_lm.rs index 772c5ec9..e86f8877 100644 --- a/candle-transformers/src/models/stable_lm.rs +++ b/candle-transformers/src/models/stable_lm.rs @@ -148,6 +148,7 @@ struct Attention { rotary_emb: Arc, kv_cache: Option<(Tensor, Tensor)>, use_cache: bool, + rotary_ndims: usize, } impl Attention { @@ -173,6 +174,7 @@ impl Attention { rotary_emb, kv_cache: None, use_cache: cfg.use_cache, + rotary_ndims: cfg.rotary_ndims(), }) } @@ -210,9 +212,16 @@ impl Attention { .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))? .transpose(1, 2)?; - let (query_states, key_states) = + let (rot_ndims, pass_ndims) = (self.rotary_ndims, self.head_dim - self.rotary_ndims); + let query_rot = query_states.narrow(D::Minus1, 0, rot_ndims)?; + let query_pass = query_states.narrow(D::Minus1, rot_ndims, pass_ndims)?; + let key_rot = key_states.narrow(D::Minus1, 0, rot_ndims)?; + let key_pass = key_states.narrow(D::Minus1, rot_ndims, pass_ndims)?; + let (query_rot, key_rot) = self.rotary_emb - .apply_rotary_emb_qkv(&query_states, &key_states, seqlen_offset)?; + .apply_rotary_emb_qkv(&query_rot, &key_rot, seqlen_offset)?; + let query_states = Tensor::cat(&[query_rot, query_pass], D::Minus1)?.contiguous()?; + let key_states = Tensor::cat(&[key_rot, key_pass], D::Minus1)?.contiguous()?; let (key_states, value_states) = match &self.kv_cache { None => (key_states, value_states), @@ -226,8 +235,8 @@ impl Attention { self.kv_cache = Some((key_states.clone(), value_states.clone())); } - let key_states = self.repeat_kv(key_states)?; - let value_states = self.repeat_kv(value_states)?; + let key_states = self.repeat_kv(key_states)?.contiguous()?; + let value_states = self.repeat_kv(value_states)?.contiguous()?; let attn_output = { let scale = 1f64 / f64::sqrt(self.head_dim as f64); -- cgit v1.2.3