//! StableLM model implementation. //! //! StableLM is a family of language models trained by Stability AI. //! This implementation supports the StableLM architecture. //! //! Key characteristics: //! - Grouped query attention (GQA) //! - Layer normalization //! - Rotary positional embeddings (RoPE) //! - Support for different model sizes (3B, 7B) //! //! References: //! - [Model Card](https://huggingface.co/stabilityai/stablelm-3b-4e1t) //! use crate::models::with_tracing::{linear, linear_no_bias, Linear}; use candle::{DType, Device, Module, Result, Tensor, D}; use candle_nn::{Activation, LayerNorm, VarBuilder}; use serde::Deserialize; use std::sync::Arc; // https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/configuration_stablelm.py #[derive(Debug, Clone, PartialEq, Deserialize)] pub struct Config { pub(crate) vocab_size: usize, pub(crate) intermediate_size: usize, pub(crate) hidden_size: usize, pub(crate) num_hidden_layers: usize, pub(crate) num_attention_heads: usize, pub(crate) num_key_value_heads: usize, pub(crate) hidden_act: Activation, pub(crate) partial_rotary_factor: f64, pub(crate) rope_theta: f64, pub(crate) max_position_embeddings: usize, pub(crate) layer_norm_eps: f64, pub(crate) use_cache: bool, #[serde(default)] pub(crate) use_qkv_bias: bool, // Used in StableLM-2 #[serde(default)] pub(crate) use_flash_attn: bool, // Not in config.json } impl Config { pub fn stablelm_3b_4e1t(use_flash_attn: bool) -> Self { Self { vocab_size: 50304, intermediate_size: 6912, hidden_size: 2560, num_hidden_layers: 32, num_attention_heads: 32, num_key_value_heads: 32, hidden_act: Activation::Silu, partial_rotary_factor: 0.25, rope_theta: 10_000., max_position_embeddings: 4096, layer_norm_eps: 1e-5, use_qkv_bias: false, use_cache: true, use_flash_attn, } } pub fn head_dim(&self) -> usize { self.hidden_size / self.num_attention_heads } pub fn rotary_ndims(&self) -> usize { (self.head_dim() as f64 * self.partial_rotary_factor) as usize } pub fn num_kv_groups(&self) -> usize { self.num_attention_heads / self.num_key_value_heads } pub fn set_use_flash_attn(&mut self, use_flash_attn: bool) { self.use_flash_attn = use_flash_attn } } #[derive(Debug)] pub(crate) struct RotaryEmbedding { sin: Tensor, cos: Tensor, } fn rotate_half(xs: &Tensor) -> Result { let xs = xs.chunk(2, D::Minus1)?; Tensor::cat(&[&xs[1].neg()?, &xs[0]], D::Minus1) } impl RotaryEmbedding { pub(crate) fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result { let dim = cfg.rotary_ndims(); let max_seq_len = cfg.max_position_embeddings; let inv_freq: Vec<_> = (0..dim) .step_by(2) .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32) .collect(); let inv_freq_len = inv_freq.len(); let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?; let t = Tensor::arange(0u32, max_seq_len as u32, dev)? .to_dtype(dtype)? .reshape((max_seq_len, 1))?; let freqs = t.matmul(&inv_freq)?; let freqs = Tensor::cat(&[&freqs, &freqs], D::Minus1)?; Ok(Self { sin: freqs.sin()?, cos: freqs.cos()?, }) } pub(crate) fn apply_rotary_emb_qkv( &self, q: &Tensor, k: &Tensor, seqlen_offset: usize, ) -> Result<(Tensor, Tensor)> { let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?; let cos = self.cos.narrow(0, seqlen_offset, seq_len)?; let sin = self.sin.narrow(0, seqlen_offset, seq_len)?; let cos = cos.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim) let sin = sin.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim) let q_embed = (q.broadcast_mul(&cos)? + rotate_half(q)?.broadcast_mul(&sin))?; let k_embed = (k.broadcast_mul(&cos)? + rotate_half(k)?.broadcast_mul(&sin))?; Ok((q_embed, k_embed)) } } #[derive(Debug)] #[allow(clippy::upper_case_acronyms)] struct MLP { gate_proj: Linear, up_proj: Linear, down_proj: Linear, act_fn: Activation, span: tracing::Span, } impl MLP { fn new(cfg: &Config, vb: VarBuilder) -> Result { let hidden_sz = cfg.hidden_size; let intermediate_sz = cfg.intermediate_size; let gate_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("gate_proj"))?; let up_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("up_proj"))?; let down_proj = linear_no_bias(intermediate_sz, hidden_sz, vb.pp("down_proj"))?; Ok(Self { gate_proj, up_proj, down_proj, act_fn: cfg.hidden_act, span: tracing::span!(tracing::Level::TRACE, "mlp"), }) } } impl Module for MLP { fn forward(&self, xs: &Tensor) -> Result { let _enter = self.span.enter(); let lhs = xs.apply(&self.gate_proj)?.apply(&self.act_fn)?; let rhs = xs.apply(&self.up_proj)?; (lhs * rhs)?.apply(&self.down_proj) } } #[cfg(feature = "flash-attn")] fn flash_attn( q: &Tensor, k: &Tensor, v: &Tensor, softmax_scale: f32, causal: bool, ) -> Result { candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal) } #[cfg(not(feature = "flash-attn"))] fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result { unimplemented!("compile with '--features flash-attn'") } #[derive(Debug)] struct Attention { q_proj: Linear, k_proj: Linear, v_proj: Linear, o_proj: Linear, num_heads: usize, num_kv_heads: usize, num_kv_groups: usize, head_dim: usize, hidden_size: usize, rotary_emb: Arc, kv_cache: Option<(Tensor, Tensor)>, use_cache: bool, rotary_ndims: usize, use_flash_attn: bool, span: tracing::Span, } impl Attention { fn new(rotary_emb: Arc, cfg: &Config, vb: VarBuilder) -> Result { let hidden_sz = cfg.hidden_size; let head_dim = cfg.head_dim(); let num_heads = cfg.num_attention_heads; let num_kv_heads = cfg.num_key_value_heads; let linear_layer = if cfg.use_qkv_bias { linear } else { linear_no_bias }; let q_proj = linear_layer(hidden_sz, num_heads * head_dim, vb.pp("q_proj"))?; let k_proj = linear_layer(hidden_sz, num_kv_heads * head_dim, vb.pp("k_proj"))?; let v_proj = linear_layer(hidden_sz, num_kv_heads * head_dim, vb.pp("v_proj"))?; let o_proj = linear_no_bias(num_heads * head_dim, hidden_sz, vb.pp("o_proj"))?; Ok(Self { q_proj, k_proj, v_proj, o_proj, num_heads, num_kv_heads, num_kv_groups: cfg.num_kv_groups(), head_dim, hidden_size: hidden_sz, rotary_emb, kv_cache: None, use_cache: cfg.use_cache, rotary_ndims: cfg.rotary_ndims(), use_flash_attn: cfg.use_flash_attn, span: tracing::span!(tracing::Level::TRACE, "attn"), }) } fn forward( &mut self, xs: &Tensor, attention_mask: Option<&Tensor>, seqlen_offset: usize, ) -> Result { let _enter = self.span.enter(); let (b_sz, q_len, _) = xs.dims3()?; let query_states = self.q_proj.forward(xs)?; let key_states = self.k_proj.forward(xs)?; let value_states = self.v_proj.forward(xs)?; let query_states = query_states .reshape((b_sz, q_len, self.num_heads, self.head_dim))? .transpose(1, 2)?; let key_states = key_states .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))? .transpose(1, 2)?; let value_states = value_states .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))? .transpose(1, 2)?; let (rot_ndims, pass_ndims) = (self.rotary_ndims, self.head_dim - self.rotary_ndims); let query_rot = query_states.narrow(D::Minus1, 0, rot_ndims)?; let query_pass = query_states.narrow(D::Minus1, rot_ndims, pass_ndims)?; let key_rot = key_states.narrow(D::Minus1, 0, rot_ndims)?; let key_pass = key_states.narrow(D::Minus1, rot_ndims, pass_ndims)?; let (query_rot, key_rot) = self.rotary_emb .apply_rotary_emb_qkv(&query_rot, &key_rot, seqlen_offset)?; let query_states = Tensor::cat(&[query_rot, query_pass], D::Minus1)?.contiguous()?; let key_states = Tensor::cat(&[key_rot, key_pass], D::Minus1)?.contiguous()?; let (key_states, value_states) = match &self.kv_cache { None => (key_states, value_states), Some((prev_k, prev_v)) => { let key_states = Tensor::cat(&[prev_k, &key_states], 2)?; let value_states = Tensor::cat(&[prev_v, &value_states], 2)?; (key_states, value_states) } }; if self.use_cache { self.kv_cache = Some((key_states.clone(), value_states.clone())); } let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?.contiguous()?; let value_states = crate::utils::repeat_kv(value_states, self.num_kv_groups)?.contiguous()?; let attn_output = if self.use_flash_attn { // flash-attn expects (b_sz, seq_len, nheads, head_dim) let q = query_states.transpose(1, 2)?; let k = key_states.transpose(1, 2)?; let v = value_states.transpose(1, 2)?; let softmax_scale = 1f32 / (self.head_dim as f32).sqrt(); flash_attn(&q, &k, &v, softmax_scale, q_len > 1)?.transpose(1, 2)? } else { let scale = 1f64 / f64::sqrt(self.head_dim as f64); let attn_weights = (query_states.matmul(&key_states.transpose(2, 3)?)? * scale)?; let attn_weights = match attention_mask { None => attn_weights, Some(mask) => attn_weights.broadcast_add(mask)?, }; let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?; attn_weights.matmul(&value_states)? }; attn_output .transpose(1, 2)? .reshape((b_sz, q_len, self.hidden_size))? .apply(&self.o_proj) } } #[derive(Debug)] struct DecoderLayer { self_attn: Attention, mlp: MLP, input_layernorm: LayerNorm, post_attention_layernorm: LayerNorm, span: tracing::Span, } impl DecoderLayer { fn new(rotary_emb: Arc, cfg: &Config, vb: VarBuilder) -> Result { let self_attn = Attention::new(rotary_emb, cfg, vb.pp("self_attn"))?; let mlp = MLP::new(cfg, vb.pp("mlp"))?; let input_layernorm = candle_nn::layer_norm( cfg.hidden_size, cfg.layer_norm_eps, vb.pp("input_layernorm"), )?; let post_attention_layernorm = candle_nn::layer_norm( cfg.hidden_size, cfg.layer_norm_eps, vb.pp("post_attention_layernorm"), )?; Ok(Self { self_attn, mlp, input_layernorm, post_attention_layernorm, span: tracing::span!(tracing::Level::TRACE, "layer"), }) } fn forward( &mut self, xs: &Tensor, attention_mask: Option<&Tensor>, seqlen_offset: usize, ) -> Result { let _enter = self.span.enter(); let residual = xs; let xs = self.input_layernorm.forward(xs)?; let xs = self.self_attn.forward(&xs, attention_mask, seqlen_offset)?; let xs = (xs + residual)?; let residual = &xs; let xs = xs.apply(&self.post_attention_layernorm)?.apply(&self.mlp)?; residual + xs } } #[derive(Debug)] pub struct Model { embed_tokens: candle_nn::Embedding, layers: Vec, norm: LayerNorm, lm_head: Linear, device: Device, dtype: DType, span: tracing::Span, } impl Model { pub fn new(cfg: &Config, vb: VarBuilder) -> Result { let vb_m = vb.pp("model"); let embed_tokens = candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?; let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb_m.device())?); let mut layers = Vec::with_capacity(cfg.num_hidden_layers); let vb_l = vb_m.pp("layers"); for layer_idx in 0..cfg.num_hidden_layers { let layer = DecoderLayer::new(rotary_emb.clone(), cfg, vb_l.pp(layer_idx))?; layers.push(layer) } let norm = candle_nn::layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb_m.pp("norm"))?; let lm_head = linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?; Ok(Self { embed_tokens, layers, norm, lm_head, device: vb.device().clone(), dtype: vb.dtype(), span: tracing::span!(tracing::Level::TRACE, "model"), }) } fn prepare_decoder_attention_mask( &self, b_size: usize, tgt_len: usize, seqlen_offset: usize, ) -> Result { // Sliding window mask? let mask: Vec<_> = (0..tgt_len) .flat_map(|i| (0..tgt_len).map(move |j| if i < j { f32::NEG_INFINITY } else { 0. })) .collect(); let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?; let mask = if seqlen_offset > 0 { let mask0 = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, &self.device)?; Tensor::cat(&[&mask0, &mask], D::Minus1)? } else { mask }; mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))? .to_dtype(self.dtype) } pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result { let _enter = self.span.enter(); let (b_size, seq_len) = input_ids.dims2()?; let attention_mask = if seq_len <= 1 { None } else { let mask = self.prepare_decoder_attention_mask(b_size, seq_len, seqlen_offset)?; Some(mask) }; let mut xs = self.embed_tokens.forward(input_ids)?; for layer in self.layers.iter_mut() { xs = layer.forward(&xs, attention_mask.as_ref(), seqlen_offset)? } xs.narrow(1, seq_len - 1, 1)? .apply(&self.norm)? .apply(&self.lm_head) } }