diff options
Diffstat (limited to 'candle-examples/examples/bigcode')
-rw-r--r-- | candle-examples/examples/bigcode/main.rs | 3 | ||||
-rw-r--r-- | candle-examples/examples/bigcode/model.rs | 359 |
2 files changed, 1 insertions, 361 deletions
diff --git a/candle-examples/examples/bigcode/main.rs b/candle-examples/examples/bigcode/main.rs index 652cd47f..3540f75d 100644 --- a/candle-examples/examples/bigcode/main.rs +++ b/candle-examples/examples/bigcode/main.rs @@ -7,8 +7,7 @@ extern crate accelerate_src; use anyhow::{Error as E, Result}; use clap::Parser; -mod model; -use model::{Config, GPTBigCode}; +use candle_transformers::models::bigcode::{Config, GPTBigCode}; use candle::{DType, Device, Tensor}; use candle_nn::VarBuilder; diff --git a/candle-examples/examples/bigcode/model.rs b/candle-examples/examples/bigcode/model.rs deleted file mode 100644 index 1e63956b..00000000 --- a/candle-examples/examples/bigcode/model.rs +++ /dev/null @@ -1,359 +0,0 @@ -use candle::{DType, Device, IndexOp, Result, Tensor, D}; -use candle_nn::{Embedding, LayerNorm, Linear, Module, VarBuilder}; - -fn linear(size1: usize, size2: usize, bias: bool, vb: VarBuilder) -> Result<Linear> { - let weight = vb.get((size2, size1), "weight")?; - let bias = if bias { - Some(vb.get(size2, "bias")?) - } else { - None - }; - Ok(Linear::new(weight, bias)) -} - -fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> { - let embeddings = vb.get((vocab_size, hidden_size), "weight")?; - Ok(Embedding::new(embeddings, hidden_size)) -} - -fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> { - let weight = vb.get(size, "weight")?; - let bias = vb.get(size, "bias")?; - Ok(LayerNorm::new(weight, bias, eps)) -} - -fn make_causal_mask(t: usize, device: &Device) -> Result<Tensor> { - let mask: Vec<_> = (0..t) - .flat_map(|i| (0..t).map(move |j| u8::from(j <= i))) - .collect(); - let mask = Tensor::from_slice(&mask, (t, t), device)?; - Ok(mask) -} - -#[derive(Debug)] -pub struct Config { - pub vocab_size: usize, - // max_position_embeddings aka n_positions - pub max_position_embeddings: usize, - // num_hidden_layers aka n_layer - pub num_hidden_layers: usize, - // hidden_size aka n_embd - pub hidden_size: usize, - pub layer_norm_epsilon: f64, - pub n_inner: Option<usize>, - // num_attention_heads aka n_head - pub num_attention_heads: usize, - pub multi_query: bool, - pub use_cache: bool, -} - -impl Config { - #[allow(dead_code)] - pub fn starcoder_1b() -> Self { - Self { - vocab_size: 49152, - max_position_embeddings: 8192, - num_hidden_layers: 24, - hidden_size: 2048, - layer_norm_epsilon: 1e-5, - n_inner: Some(8192), - num_attention_heads: 16, - multi_query: true, - use_cache: true, - } - } - - #[allow(dead_code)] - pub fn starcoder_3b() -> Self { - Self { - vocab_size: 49152, - max_position_embeddings: 8192, - num_hidden_layers: 36, - hidden_size: 2816, - layer_norm_epsilon: 1e-5, - n_inner: Some(11264), - num_attention_heads: 22, - multi_query: true, - use_cache: true, - } - } - - #[allow(dead_code)] - pub fn starcoder_7b() -> Self { - Self { - vocab_size: 49152, - max_position_embeddings: 8192, - num_hidden_layers: 42, - hidden_size: 4096, - layer_norm_epsilon: 1e-5, - n_inner: Some(16384), - num_attention_heads: 32, - multi_query: true, - use_cache: true, - } - } - - #[allow(dead_code)] - pub fn starcoder() -> Self { - Self { - vocab_size: 49152, - max_position_embeddings: 8192, - num_hidden_layers: 40, - hidden_size: 6144, - layer_norm_epsilon: 1e-5, - n_inner: Some(24576), - num_attention_heads: 48, - multi_query: true, - use_cache: true, - } - } -} - -struct Attention { - c_attn: Linear, - c_proj: Linear, - kv_cache: Option<Tensor>, - use_cache: bool, - embed_dim: usize, - kv_dim: usize, - num_heads: usize, - head_dim: usize, - multi_query: bool, -} - -impl Attention { - pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> { - let hidden_size = cfg.hidden_size; - let head_dim = hidden_size / cfg.num_attention_heads; - let kv_heads = if cfg.multi_query { - 1 - } else { - cfg.num_attention_heads - }; - let kv_dim = kv_heads * head_dim; - let c_attn = linear(hidden_size, hidden_size + 2 * kv_dim, true, vb.pp("c_attn"))?; - let c_proj = linear(hidden_size, hidden_size, true, vb.pp("c_proj"))?; - Ok(Self { - c_proj, - c_attn, - embed_dim: hidden_size, - kv_cache: None, - use_cache: cfg.use_cache, - kv_dim, - head_dim, - num_heads: cfg.num_attention_heads, - multi_query: cfg.multi_query, - }) - } - - fn attn( - &self, - query: &Tensor, - key: &Tensor, - value: &Tensor, - attention_mask: &Tensor, - ) -> Result<Tensor> { - if query.dtype() != DType::F32 { - // If we start supporting f16 models, we may need the upcasting scaling bits. - // https://github.com/huggingface/transformers/blob/a0042379269bea9182c1f87e6b2eee4ba4c8cce8/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py#L133 - candle::bail!("upcasting is not supported {:?}", query.dtype()) - } - let scale_factor = 1f64 / (self.head_dim as f64).sqrt(); - let initial_query_shape = query.shape(); - let key_len = key.dim(D::Minus1)?; - let (query, key, attn_shape, attn_view) = if self.multi_query { - let (b_sz, query_len, _) = query.dims3()?; - let query = query.reshape((b_sz, query_len * self.num_heads, self.head_dim))?; - let attn_shape = (b_sz, query_len, self.num_heads, key_len); - let attn_view = (b_sz, query_len * self.num_heads, key_len); - (query, key.clone(), attn_shape, attn_view) - } else { - let (b_sz, _num_heads, query_len, _head_dim) = query.dims4()?; - let query = query.reshape((b_sz, query_len * self.num_heads, self.head_dim))?; - let key = key.reshape((b_sz * self.num_heads, self.head_dim, key_len))?; - let attn_shape = (b_sz, self.num_heads, query_len, key_len); - let attn_view = (b_sz * self.num_heads, query_len, key_len); - (query, key, attn_shape, attn_view) - }; - - let attn_weights = - (query.matmul(&key.contiguous()?)? * scale_factor)?.reshape(attn_shape)?; - let attention_mask = attention_mask.broadcast_as(attn_shape)?; - let mask_value = - Tensor::new(f32::NEG_INFINITY, query.device())?.broadcast_as(attn_shape)?; - let attn_weights = attention_mask.where_cond(&attn_weights, &mask_value)?; - let attn_weights = candle_nn::ops::softmax(&attn_weights, D::Minus1)?; - let value = value.contiguous()?; - let attn_output = if self.multi_query { - attn_weights - .reshape(attn_view)? - .matmul(&value)? - .reshape(initial_query_shape)? - } else { - attn_weights.matmul(&value)? - }; - Ok(attn_output) - } - - fn forward(&mut self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> { - let qkv = self.c_attn.forward(hidden_states)?; - let (query, key_value) = if self.multi_query { - let query = qkv.i((.., .., ..self.embed_dim))?; - let key_value = qkv.i((.., .., self.embed_dim..self.embed_dim + 2 * self.kv_dim))?; - (query, key_value) - } else { - let mut dims = qkv.dims().to_vec(); - dims.pop(); - dims.push(self.embed_dim); - dims.push(self.head_dim * 3); - let qkv = qkv.reshape(dims)?.transpose(1, 2)?; - let query = qkv.i((.., .., .., ..self.head_dim))?; - let key_value = qkv.i((.., .., .., self.head_dim..3 * self.head_dim))?; - (query, key_value) - }; - let mut key_value = key_value; - if self.use_cache { - if let Some(kv_cache) = &self.kv_cache { - // TODO: we could trim the tensors to MAX_SEQ_LEN so that this would work for - // arbitrarily large sizes. - key_value = Tensor::cat(&[kv_cache, &key_value], D::Minus2)?.contiguous()?; - } - self.kv_cache = Some(key_value.clone()) - } - - let key = key_value.narrow(D::Minus1, 0, self.head_dim)?; - let value = key_value.narrow(D::Minus1, self.head_dim, self.head_dim)?; - let attn_output = self.attn(&query, &key.t()?, &value, attention_mask)?; - let attn_output = if self.multi_query { - attn_output - } else { - attn_output - .transpose(1, 2)? - .reshape(hidden_states.shape())? - }; - let attn_output = self.c_proj.forward(&attn_output)?; - Ok(attn_output) - } -} - -struct Mlp { - c_fc: Linear, - c_proj: Linear, -} - -impl Mlp { - fn load(inner_dim: usize, vb: VarBuilder, cfg: &Config) -> Result<Self> { - let c_fc = linear(cfg.hidden_size, inner_dim, true, vb.pp("c_fc"))?; - let c_proj = linear(inner_dim, cfg.hidden_size, true, vb.pp("c_proj"))?; - Ok(Self { c_fc, c_proj }) - } - - fn forward(&mut self, hidden_states: &Tensor) -> Result<Tensor> { - let hidden_states = self.c_fc.forward(hidden_states)?.gelu()?; - let hidden_states = self.c_proj.forward(&hidden_states)?; - Ok(hidden_states) - } -} - -// TODO: Add cross-attention? -struct Block { - ln_1: LayerNorm, - attn: Attention, - ln_2: LayerNorm, - mlp: Mlp, -} - -impl Block { - fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> { - let hidden_size = cfg.hidden_size; - let inner_dim = cfg.n_inner.unwrap_or(4 * hidden_size); - let ln_1 = layer_norm(hidden_size, cfg.layer_norm_epsilon, vb.pp("ln_1"))?; - let attn = Attention::load(vb.pp("attn"), cfg)?; - let ln_2 = layer_norm(hidden_size, cfg.layer_norm_epsilon, vb.pp("ln_2"))?; - let mlp = Mlp::load(inner_dim, vb.pp("mlp"), cfg)?; - Ok(Self { - ln_1, - attn, - ln_2, - mlp, - }) - } - - fn forward(&mut self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> { - let residual = hidden_states; - let hidden_states = self.ln_1.forward(hidden_states)?; - let attn_outputs = self.attn.forward(&hidden_states, attention_mask)?; - let hidden_states = (&attn_outputs + residual)?; - let residual = &hidden_states; - let hidden_states = self.ln_2.forward(&hidden_states)?; - let hidden_states = self.mlp.forward(&hidden_states)?; - let hidden_states = (&hidden_states + residual)?; - Ok(hidden_states) - } -} - -pub struct GPTBigCode { - wte: Embedding, - wpe: Embedding, - blocks: Vec<Block>, - ln_f: LayerNorm, - lm_head: Linear, - bias: Tensor, - config: Config, -} - -impl GPTBigCode { - pub fn config(&self) -> &Config { - &self.config - } - - pub fn load(vb: VarBuilder, cfg: Config) -> Result<Self> { - let hidden_size = cfg.hidden_size; - let vb_t = vb.pp("transformer"); - let wte = embedding(cfg.vocab_size, hidden_size, vb_t.pp("wte"))?; - let wpe = embedding(cfg.max_position_embeddings, hidden_size, vb_t.pp("wpe"))?; - let blocks = (0..cfg.num_hidden_layers) - .map(|i| Block::load(vb_t.pp(&format!("h.{i}")), &cfg)) - .collect::<Result<Vec<_>>>()?; - let ln_f = layer_norm(hidden_size, cfg.layer_norm_epsilon, vb_t.pp("ln_f"))?; - let lm_head = linear(hidden_size, cfg.vocab_size, false, vb_t.pp("wte"))?; - let bias = make_causal_mask(cfg.max_position_embeddings, vb.device())?; - Ok(Self { - wte, - wpe, - blocks, - lm_head, - ln_f, - bias, - config: cfg, - }) - } - - pub fn forward(&mut self, input_ids: &Tensor, past_len: usize) -> Result<Tensor> { - let dev = input_ids.device(); - let (b_sz, seq_len) = input_ids.dims2()?; - - let key_len = past_len + seq_len; - let attention_mask = self.bias.i((past_len..key_len, ..key_len))?.unsqueeze(0)?; - // MQA models: (batch_size, query_length, n_heads, key_length) - // MHA models: (batch_size, n_heads, query_length, key_length) - let seq_len_dim = if self.config.multi_query { 2 } else { 1 }; - let attention_mask = attention_mask.unsqueeze(seq_len_dim)?; - - let position_ids = Tensor::arange(past_len as u32, (past_len + seq_len) as u32, dev)?; - let position_ids = position_ids.unsqueeze(0)?.broadcast_as((b_sz, seq_len))?; - let input_embeds = self.wte.forward(input_ids)?; - let position_embeds = self.wpe.forward(&position_ids)?; - - let mut hidden_states = (&input_embeds + &position_embeds)?; - for block in self.blocks.iter_mut() { - hidden_states = block.forward(&hidden_states, &attention_mask)?; - } - let hidden_states = self.ln_f.forward(&hidden_states)?; - let hidden_states = hidden_states - .reshape((b_sz, seq_len, self.config.hidden_size))? - .narrow(1, seq_len - 1, 1)?; - let logits = self.lm_head.forward(&hidden_states)?.squeeze(1)?; - Ok(logits) - } -} |