//! BigCode implementation in Rust based on the GPT-BigCode model. //! //! [StarCoder/BigCode](https://huggingface.co/bigcode/starcoderbase-1b) is a LLM //! model specialized to code generation. The initial model was trained on 80 //! programming languages. See "StarCoder: A State-of-the-Art LLM for Code", Mukherjee et al. 2023 //! - [Arxiv](https://arxiv.org/abs/2305.06161) //! - [Github](https://github.com/bigcode-project/starcoder) //! //! ## Running some example //! //! ```bash //! cargo run --example bigcode --release -- --prompt "fn fact(n: u64) -> u64" //! //! > fn fact(n: u64) -> u64 { //! > if n == 0 { //! > 1 //! > } else { //! > n * fact(n - 1) //! > } //! > } //! ``` //! use candle::{DType, Device, IndexOp, Result, Tensor, D}; use candle_nn::{embedding, linear_b as linear, Embedding, LayerNorm, Linear, Module, VarBuilder}; fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> { let weight = vb.get(size, "weight")?; let bias = vb.get(size, "bias")?; Ok(LayerNorm::new(weight, bias, eps)) } fn make_causal_mask(t: usize, device: &Device) -> Result<Tensor> { let mask: Vec<_> = (0..t) .flat_map(|i| (0..t).map(move |j| u8::from(j <= i))) .collect(); let mask = Tensor::from_slice(&mask, (t, t), device)?; Ok(mask) } #[derive(Debug)] pub struct Config { pub vocab_size: usize, // max_position_embeddings aka n_positions pub max_position_embeddings: usize, // num_hidden_layers aka n_layer pub num_hidden_layers: usize, // hidden_size aka n_embd pub hidden_size: usize, pub layer_norm_epsilon: f64, pub n_inner: Option<usize>, // num_attention_heads aka n_head pub num_attention_heads: usize, pub multi_query: bool, pub use_cache: bool, } impl Config { #[allow(dead_code)] pub fn starcoder_1b() -> Self { Self { vocab_size: 49152, max_position_embeddings: 8192, num_hidden_layers: 24, hidden_size: 2048, layer_norm_epsilon: 1e-5, n_inner: Some(8192), num_attention_heads: 16, multi_query: true, use_cache: true, } } #[allow(dead_code)] pub fn starcoder_3b() -> Self { Self { vocab_size: 49152, max_position_embeddings: 8192, num_hidden_layers: 36, hidden_size: 2816, layer_norm_epsilon: 1e-5, n_inner: Some(11264), num_attention_heads: 22, multi_query: true, use_cache: true, } } #[allow(dead_code)] pub fn starcoder_7b() -> Self { Self { vocab_size: 49152, max_position_embeddings: 8192, num_hidden_layers: 42, hidden_size: 4096, layer_norm_epsilon: 1e-5, n_inner: Some(16384), num_attention_heads: 32, multi_query: true, use_cache: true, } } #[allow(dead_code)] pub fn starcoder() -> Self { Self { vocab_size: 49152, max_position_embeddings: 8192, num_hidden_layers: 40, hidden_size: 6144, layer_norm_epsilon: 1e-5, n_inner: Some(24576), num_attention_heads: 48, multi_query: true, use_cache: true, } } } struct Attention { c_attn: Linear, c_proj: Linear, kv_cache: Option<Tensor>, use_cache: bool, embed_dim: usize, kv_dim: usize, num_heads: usize, head_dim: usize, multi_query: bool, } impl Attention { pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> { let hidden_size = cfg.hidden_size; let head_dim = hidden_size / cfg.num_attention_heads; let kv_heads = if cfg.multi_query { 1 } else { cfg.num_attention_heads }; let kv_dim = kv_heads * head_dim; let c_attn = linear(hidden_size, hidden_size + 2 * kv_dim, true, vb.pp("c_attn"))?; let c_proj = linear(hidden_size, hidden_size, true, vb.pp("c_proj"))?; Ok(Self { c_proj, c_attn, embed_dim: hidden_size, kv_cache: None, use_cache: cfg.use_cache, kv_dim, head_dim, num_heads: cfg.num_attention_heads, multi_query: cfg.multi_query, }) } fn attn( &self, query: &Tensor, key: &Tensor, value: &Tensor, attention_mask: &Tensor, ) -> Result<Tensor> { if query.dtype() != DType::F32 { // If we start supporting f16 models, we may need the upcasting scaling bits. // https://github.com/huggingface/transformers/blob/a0042379269bea9182c1f87e6b2eee4ba4c8cce8/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py#L133 candle::bail!("upcasting is not supported {:?}", query.dtype()) } let scale_factor = 1f64 / (self.head_dim as f64).sqrt(); let initial_query_shape = query.shape(); let key_len = key.dim(D::Minus1)?; let (query, key, attn_shape, attn_view) = if self.multi_query { let (b_sz, query_len, _) = query.dims3()?; let query = query.reshape((b_sz, query_len * self.num_heads, self.head_dim))?; let attn_shape = (b_sz, query_len, self.num_heads, key_len); let attn_view = (b_sz, query_len * self.num_heads, key_len); (query, key.clone(), attn_shape, attn_view) } else { let (b_sz, _num_heads, query_len, _head_dim) = query.dims4()?; let query = query.reshape((b_sz, query_len * self.num_heads, self.head_dim))?; let key = key.reshape((b_sz * self.num_heads, self.head_dim, key_len))?; let attn_shape = (b_sz, self.num_heads, query_len, key_len); let attn_view = (b_sz * self.num_heads, query_len, key_len); (query, key, attn_shape, attn_view) }; let attn_weights = (query.matmul(&key.contiguous()?)? * scale_factor)?.reshape(attn_shape)?; let attention_mask = attention_mask.broadcast_as(attn_shape)?; let mask_value = Tensor::new(f32::NEG_INFINITY, query.device())?.broadcast_as(attn_shape)?; let attn_weights = attention_mask.where_cond(&attn_weights, &mask_value)?; let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?; let value = value.contiguous()?; let attn_output = if self.multi_query { attn_weights .reshape(attn_view)? .matmul(&value)? .reshape(initial_query_shape)? } else { attn_weights.matmul(&value)? }; Ok(attn_output) } fn forward(&mut self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> { let qkv = self.c_attn.forward(hidden_states)?; let (query, key_value) = if self.multi_query { let query = qkv.i((.., .., ..self.embed_dim))?; let key_value = qkv.i((.., .., self.embed_dim..self.embed_dim + 2 * self.kv_dim))?; (query, key_value) } else { let mut dims = qkv.dims().to_vec(); dims.pop(); dims.push(self.embed_dim); dims.push(self.head_dim * 3); let qkv = qkv.reshape(dims)?.transpose(1, 2)?; let query = qkv.i((.., .., .., ..self.head_dim))?; let key_value = qkv.i((.., .., .., self.head_dim..3 * self.head_dim))?; (query, key_value) }; let mut key_value = key_value; if self.use_cache { if let Some(kv_cache) = &self.kv_cache { // TODO: we could trim the tensors to MAX_SEQ_LEN so that this would work for // arbitrarily large sizes. key_value = Tensor::cat(&[kv_cache, &key_value], D::Minus2)?.contiguous()?; } self.kv_cache = Some(key_value.clone()) } let key = key_value.narrow(D::Minus1, 0, self.head_dim)?; let value = key_value.narrow(D::Minus1, self.head_dim, self.head_dim)?; let attn_output = self.attn(&query, &key.t()?, &value, attention_mask)?; let attn_output = if self.multi_query { attn_output } else { attn_output .transpose(1, 2)? .reshape(hidden_states.shape())? }; let attn_output = self.c_proj.forward(&attn_output)?; Ok(attn_output) } } struct Mlp { c_fc: Linear, c_proj: Linear, } impl Mlp { fn load(inner_dim: usize, vb: VarBuilder, cfg: &Config) -> Result<Self> { let c_fc = linear(cfg.hidden_size, inner_dim, true, vb.pp("c_fc"))?; let c_proj = linear(inner_dim, cfg.hidden_size, true, vb.pp("c_proj"))?; Ok(Self { c_fc, c_proj }) } fn forward(&mut self, hidden_states: &Tensor) -> Result<Tensor> { let hidden_states = self.c_fc.forward(hidden_states)?.gelu()?; let hidden_states = self.c_proj.forward(&hidden_states)?; Ok(hidden_states) } } // TODO: Add cross-attention? struct Block { ln_1: LayerNorm, attn: Attention, ln_2: LayerNorm, mlp: Mlp, } impl Block { fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> { let hidden_size = cfg.hidden_size; let inner_dim = cfg.n_inner.unwrap_or(4 * hidden_size); let ln_1 = layer_norm(hidden_size, cfg.layer_norm_epsilon, vb.pp("ln_1"))?; let attn = Attention::load(vb.pp("attn"), cfg)?; let ln_2 = layer_norm(hidden_size, cfg.layer_norm_epsilon, vb.pp("ln_2"))?; let mlp = Mlp::load(inner_dim, vb.pp("mlp"), cfg)?; Ok(Self { ln_1, attn, ln_2, mlp, }) } fn forward(&mut self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> { let residual = hidden_states; let hidden_states = self.ln_1.forward(hidden_states)?; let attn_outputs = self.attn.forward(&hidden_states, attention_mask)?; let hidden_states = (&attn_outputs + residual)?; let residual = &hidden_states; let hidden_states = self.ln_2.forward(&hidden_states)?; let hidden_states = self.mlp.forward(&hidden_states)?; let hidden_states = (&hidden_states + residual)?; Ok(hidden_states) } } pub struct GPTBigCode { wte: Embedding, wpe: Embedding, blocks: Vec<Block>, ln_f: LayerNorm, lm_head: Linear, bias: Tensor, config: Config, } impl GPTBigCode { pub fn config(&self) -> &Config { &self.config } pub fn load(vb: VarBuilder, cfg: Config) -> Result<Self> { let hidden_size = cfg.hidden_size; let vb_t = vb.pp("transformer"); let wte = embedding(cfg.vocab_size, hidden_size, vb_t.pp("wte"))?; let wpe = embedding(cfg.max_position_embeddings, hidden_size, vb_t.pp("wpe"))?; let blocks = (0..cfg.num_hidden_layers) .map(|i| Block::load(vb_t.pp(format!("h.{i}")), &cfg)) .collect::<Result<Vec<_>>>()?; let ln_f = layer_norm(hidden_size, cfg.layer_norm_epsilon, vb_t.pp("ln_f"))?; let lm_head = linear(hidden_size, cfg.vocab_size, false, vb_t.pp("wte"))?; let bias = make_causal_mask(cfg.max_position_embeddings, vb.device())?; Ok(Self { wte, wpe, blocks, lm_head, ln_f, bias, config: cfg, }) } pub fn forward(&mut self, input_ids: &Tensor, past_len: usize) -> Result<Tensor> { let dev = input_ids.device(); let (b_sz, seq_len) = input_ids.dims2()?; let key_len = past_len + seq_len; let attention_mask = self.bias.i((past_len..key_len, ..key_len))?.unsqueeze(0)?; // MQA models: (batch_size, query_length, n_heads, key_length) // MHA models: (batch_size, n_heads, query_length, key_length) let seq_len_dim = if self.config.multi_query { 2 } else { 1 }; let attention_mask = attention_mask.unsqueeze(seq_len_dim)?; let position_ids = Tensor::arange(past_len as u32, (past_len + seq_len) as u32, dev)?; let position_ids = position_ids.unsqueeze(0)?.broadcast_as((b_sz, seq_len))?; let input_embeds = self.wte.forward(input_ids)?; let position_embeds = self.wpe.forward(&position_ids)?; let mut hidden_states = (&input_embeds + &position_embeds)?; for block in self.blocks.iter_mut() { hidden_states = block.forward(&hidden_states, &attention_mask)?; } let hidden_states = self.ln_f.forward(&hidden_states)?; let hidden_states = hidden_states .reshape((b_sz, seq_len, self.config.hidden_size))? .narrow(1, seq_len - 1, 1)?; let logits = self.lm_head.forward(&hidden_states)?.squeeze(1)?; Ok(logits) } }