Diffstat (limited to 'candle-examples/examples/falcon/model.rs')
-rw-r--r-- | candle-examples/examples/falcon/model.rs | 485 |
1 file changed, 0 insertions, 485 deletions
diff --git a/candle-examples/examples/falcon/model.rs b/candle-examples/examples/falcon/model.rs
deleted file mode 100644
index b638dd51..00000000
--- a/candle-examples/examples/falcon/model.rs
+++ /dev/null
@@ -1,485 +0,0 @@
-use anyhow::Result;
-use candle::{DType, Device, Tensor, D};
-use candle_nn::{Embedding, LayerNorm, Linear, Module, VarBuilder};
-
-const MAX_SEQ_LEN: usize = 5000;
-
-fn linear(size1: usize, size2: usize, bias: bool, vb: VarBuilder) -> Result<Linear> {
-    let weight = vb.get((size2, size1), "weight")?;
-    let bias = if bias {
-        Some(vb.get(size2, "bias")?)
-    } else {
-        None
-    };
-    Ok(Linear::new(weight, bias))
-}
-
-fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> {
-    let (weight, bias) = match (vb.get(size, "weight"), vb.get(size, "bias")) {
-        (Ok(weight), Ok(bias)) => (weight, bias),
-        (Err(err), _) | (_, Err(err)) => {
-            if let (Ok(weight), Ok(bias)) = (vb.get(size, "gamma"), vb.get(size, "beta")) {
-                (weight, bias)
-            } else {
-                return Err(err.into());
-            }
-        }
-    };
-    Ok(LayerNorm::new(weight, bias, eps))
-}
-
-fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
-    let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
-    Ok(Embedding::new(embeddings, hidden_size))
-}
-
-// https://raw.githubusercontent.com/huggingface/transformers/030c863aaa0165e98352b61697430bf69bf33755/src/transformers/models/falcon/configuration_falcon.py
-#[derive(Debug)]
-pub struct Config {
-    pub vocab_size: usize,
-    pub hidden_size: usize,
-    pub num_hidden_layers: usize,
-    pub num_attention_heads: usize,
-    pub layer_norm_epsilon: f64,
-    pub initializer_range: f64,
-    pub use_cache: bool,
-    pub bos_token_id: u32,
-    pub eos_token_id: u32,
-    pub hidden_dropout: f64,
-    pub attention_dropout: f64,
-    pub n_head_kv: Option<usize>,
-    pub alibi: bool,
-    pub new_decoder_architecture: bool,
-    pub multi_query: bool,
-    pub parallel_attn: bool,
-    pub bias: bool,
-}
-
-impl Default for Config {
-    fn default() -> Self {
-        Self {
-            vocab_size: 65024,
-            hidden_size: 4544,
-            num_hidden_layers: 32,
-            num_attention_heads: 71,
-            layer_norm_epsilon: 1e-5,
-            initializer_range: 0.02,
-            use_cache: true,
-            bos_token_id: 11,
-            eos_token_id: 11,
-            hidden_dropout: 0.0,
-            attention_dropout: 0.0,
-            n_head_kv: None,
-            alibi: false,
-            new_decoder_architecture: false,
-            multi_query: true,
-            parallel_attn: true,
-            bias: false,
-        }
-    }
-}
-
-impl Config {
-    pub fn validate(&self) -> Result<()> {
-        if self.alibi {
-            anyhow::bail!("alibi is not supported");
-        }
-        if self.new_decoder_architecture {
-            anyhow::bail!("new_decoder_architecture is not supported");
-        }
-        if self.n_head_kv.is_some() {
-            anyhow::bail!("n_head_kv is not supported");
-        }
-        Ok(())
-    }
-
-    // https://huggingface.co/tiiuae/falcon-7b/blob/main/config.json
-    pub fn falcon7b() -> Self {
-        // This is currently on par with the defaults, the defaults come from the Python default
-        // arguments for the config initialization whereas the following come from the json config.
-        Self {
-            vocab_size: 65024,
-            hidden_size: 4544,
-            num_hidden_layers: 32,
-            num_attention_heads: 71,
-            layer_norm_epsilon: 1e-5,
-            initializer_range: 0.02,
-            use_cache: true,
-            bos_token_id: 11,
-            eos_token_id: 11,
-            hidden_dropout: 0.,
-            attention_dropout: 0.,
-            n_head_kv: None,
-            alibi: false,
-            new_decoder_architecture: false,
-            multi_query: true,
-            parallel_attn: true,
-            bias: false,
-        }
-    }
-
-    fn head_dim(&self) -> usize {
-        self.hidden_size / self.num_attention_heads
-    }
-
-    fn rotary(&self) -> bool {
-        !self.alibi
-    }
-}
-
-fn rotate_half(x: &Tensor) -> Result<Tensor> {
-    let l = x.dim(D::Minus1)?;
-    let x1 = x.narrow(D::Minus1, 0, l / 2)?;
-    let x2 = x.narrow(D::Minus1, l / 2, l - l / 2)?;
-    let x21 = Tensor::cat(&[&x2.neg()?, &x1], D::Minus1)?;
-    Ok(x21)
-}
-
-#[derive(Debug)]
-struct FalconRotaryEmbedding {
-    inv_freq: Tensor,
-    cache: Option<(usize, Tensor, Tensor)>,
-}
-
-impl FalconRotaryEmbedding {
-    fn load(device: &Device, cfg: &Config) -> Result<Self> {
-        let head_dim = cfg.head_dim();
-        let inv_freq: Vec<_> = (0..head_dim)
-            .step_by(2)
-            .map(|i| 1f32 / 10000f32.powf(i as f32 / head_dim as f32))
-            .collect();
-        Ok(Self {
-            inv_freq: Tensor::new(inv_freq.as_slice(), device)?,
-            cache: None,
-        })
-    }
-
-    fn cos_sin(
-        &mut self,
-        seq_len: usize,
-        device: &Device,
-        dtype: DType,
-    ) -> Result<(Tensor, Tensor)> {
-        match &self.cache {
-            Some((s, cos, sin)) if *s == seq_len => {
-                return Ok((cos.clone(), sin.clone()));
-            }
-            _ => {}
-        }
-        let t = Tensor::arange(0, seq_len as u32, device)?.to_dtype(dtype)?;
-        let inv_freq = self.inv_freq.to_dtype(dtype)?;
-        let freqs = t.unsqueeze(1)?.matmul(&inv_freq.unsqueeze(0)?)?;
-        let emb = Tensor::cat(&[&freqs, &freqs], D::Minus1)?;
-        let cos = emb.cos()?;
-        let sin = emb.sin()?;
-        self.cache = Some((seq_len, cos.clone(), sin.clone()));
-        Ok((cos, sin))
-    }
-
-    fn forward(
-        &mut self,
-        query: &Tensor,
-        key: &Tensor,
-        past_kv_len: usize,
-    ) -> Result<(Tensor, Tensor)> {
-        let (_batch, seq_len, _head_dim) = query.dims3()?;
-        let (cos, sin) = self.cos_sin(MAX_SEQ_LEN, query.device(), query.dtype())?;
-        let cos = cos.narrow(0, past_kv_len, seq_len)?;
-        let sin = sin.narrow(0, past_kv_len, seq_len)?;
-        let qs = (query.broadcast_mul(&cos)? + &rotate_half(query)?.broadcast_mul(&sin)?)?;
-        let ks = (key.broadcast_mul(&cos)? + &rotate_half(key)?.broadcast_mul(&sin)?)?;
-        Ok((qs, ks))
-    }
-}
-
-fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
-    let shape = mask.shape();
-    let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
-    let m = mask.where_cond(&on_true, on_false)?;
-    Ok(m)
-}
-
-#[derive(Debug)]
-struct FalconAttention {
-    query_key_value: Linear,
-    dense: Linear,
-    maybe_rotary: Option<FalconRotaryEmbedding>,
-    kv_cache: Option<(Tensor, Tensor)>,
-    inv_norm_factor: f64,
-    multi_query: bool,
-    use_cache: bool,
-    num_heads: usize,
-    head_dim: usize,
-    n_head_kv: usize,
-}
-
-impl FalconAttention {
-    fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
-        let maybe_rotary = if cfg.rotary() {
-            let rotary = FalconRotaryEmbedding::load(vb.device(), cfg)?;
-            Some(rotary)
-        } else {
-            None
-        };
-        let head_dim = cfg.head_dim();
-        let hidden_size = cfg.hidden_size;
-        let qkv_out_dim = if cfg.multi_query {
-            hidden_size + 2 * head_dim
-        } else {
-            3 * hidden_size
-        };
-        let query_key_value = linear(hidden_size, qkv_out_dim, cfg.bias, vb.pp("query_key_value"))?;
-        let dense = linear(hidden_size, hidden_size, cfg.bias, vb.pp("dense"))?;
-        Ok(Self {
-            query_key_value,
-            dense,
-            maybe_rotary,
-            kv_cache: None,
-            inv_norm_factor: 1. / (head_dim as f64).sqrt(),
-            multi_query: cfg.multi_query,
-            use_cache: cfg.use_cache,
-            num_heads: cfg.num_attention_heads,
-            n_head_kv: cfg.n_head_kv.unwrap_or(1),
-            head_dim,
-        })
-    }
-
-    fn split_heads(&self, fused_qkv: &Tensor) -> Result<(Tensor, Tensor, Tensor)> {
-        let (b_sz, seq_len, _) = fused_qkv.dims3()?;
-        if !self.multi_query {
-            let fused_qkv = fused_qkv.reshape((b_sz, seq_len, self.num_heads, 3, self.head_dim))?;
-            let q = fused_qkv.narrow(D::Minus2, 0, 1)?.squeeze(D::Minus2)?;
-            let k = fused_qkv.narrow(D::Minus2, 1, 1)?.squeeze(D::Minus2)?;
-            let v = fused_qkv.narrow(D::Minus2, 2, 1)?.squeeze(D::Minus2)?;
-            Ok((q, k, v))
-        } else {
-            let fused_qkv =
-                fused_qkv.reshape((b_sz, seq_len, self.num_heads + 2, self.head_dim))?;
-            let d = fused_qkv.dim(D::Minus2)?;
-            let q = fused_qkv.narrow(D::Minus2, 0, d - 2)?;
-            let k = fused_qkv.narrow(D::Minus2, d - 2, 1)?;
-            let v = fused_qkv.narrow(D::Minus2, d - 1, 1)?;
-            Ok((q, k, v))
-        }
-    }
-
-    fn forward(&mut self, x: &Tensor, mask: &Tensor, past_kv_len: usize) -> Result<Tensor> {
-        let fused_qkv = self.query_key_value.forward(x)?;
-        let head_dim = self.head_dim;
-        let (query, key, value) = self.split_heads(&fused_qkv)?;
-        let (b_sz, seq_len, _, _) = query.dims4()?;
-        let query = query
-            .transpose(1, 2)?
-            .reshape((b_sz * self.num_heads, seq_len, head_dim))?;
-        let key = key
-            .transpose(1, 2)?
-            .reshape((b_sz * self.n_head_kv, seq_len, head_dim))?;
-        let value = value
-            .transpose(1, 2)?
-            .reshape((b_sz * self.n_head_kv, seq_len, head_dim))?;
-        let (query, key) = if let Some(r) = &mut self.maybe_rotary {
-            r.forward(&query, &key, past_kv_len)?
-        } else {
-            (query, key)
-        };
-        let (mut key, mut value) = (key, value);
-        let mask = masked_fill(&mask.to_dtype(DType::F32)?, mask, -1e9)?.to_dtype(query.dtype())?;
-        if self.use_cache {
-            if let Some((cache_k, cache_v)) = &self.kv_cache {
-                // TODO: we could trim the tensors to MAX_SEQ_LEN so that this would work for
-                // arbitrarily large sizes.
-                key = Tensor::cat(&[cache_k, &key], 1)?.contiguous()?;
-                value = Tensor::cat(&[cache_v, &value], 1)?.contiguous()?;
-            }
-            self.kv_cache = Some((key.clone(), value.clone()))
-        }
-        let query = query.reshape((b_sz * self.num_heads, seq_len, head_dim))?;
-        let all_len = past_kv_len + seq_len;
-        let key = key.reshape((b_sz * self.n_head_kv, all_len, head_dim))?;
-        let value = value.reshape((b_sz * self.n_head_kv, all_len, head_dim))?;
-
-        let (key, value) = if self.n_head_kv == 1 {
-            (
-                key.broadcast_as((b_sz * self.num_heads, all_len, head_dim))?,
-                value.broadcast_as((b_sz * self.num_heads, all_len, head_dim))?,
-            )
-        } else {
-            (key, value)
-        };
-
-        // Only handle the case where alibi is None here, and non-flash attention.
-        let attention_scores = (query.matmul(&key.t()?)? * self.inv_norm_factor)?;
-        let attention_scores = candle_nn::ops::softmax(
-            &attention_scores
-                .broadcast_add(&mask.squeeze(1)?)?
-                .to_dtype(DType::F32)?,
-            D::Minus1,
-        )?
-        .to_dtype(x.dtype())?;
-        let attn_output = attention_scores
-            .matmul(&value)?
-            .reshape((b_sz, self.num_heads, seq_len, head_dim))?
-            .transpose(1, 2)?
-            .reshape((b_sz, seq_len, self.num_heads * head_dim))?;
-        let attn_output = self.dense.forward(&attn_output)?;
-        Ok(attn_output)
-    }
-}
-
-#[derive(Debug)]
-struct FalconMlp {
-    dense_h_to_4h: Linear,
-    dense_4h_to_h: Linear,
-}
-
-impl FalconMlp {
-    fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
-        let h = cfg.hidden_size;
-        let b = cfg.bias;
-        let dense_h_to_4h = linear(h, 4 * h, b, vb.pp("dense_h_to_4h"))?;
-        let dense_4h_to_h = linear(4 * h, h, b, vb.pp("dense_4h_to_h"))?;
-        Ok(Self {
-            dense_h_to_4h,
-            dense_4h_to_h,
-        })
-    }
-
-    fn forward(&self, x: &Tensor) -> Result<Tensor> {
-        let x = self.dense_h_to_4h.forward(x)?.gelu()?;
-        let x = self.dense_4h_to_h.forward(&x)?;
-        Ok(x)
-    }
-}
-
-#[derive(Debug)]
-struct FalconDecoderLayer {
-    inp_layernorm: LayerNorm,
-    self_attention: FalconAttention,
-    post_attention_layernorm: Option<LayerNorm>,
-    mlp: FalconMlp,
-    parallel_attn: bool,
-}
-
-impl FalconDecoderLayer {
-    fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
-        let mlp = FalconMlp::load(vb.pp("mlp"), cfg)?;
-        let inp_layernorm = layer_norm(
-            cfg.hidden_size,
-            cfg.layer_norm_epsilon,
-            vb.pp("input_layernorm"),
-        )?;
-        let self_attention = FalconAttention::load(vb.pp("self_attention"), cfg)?;
-        let post_attention_layernorm = if cfg.parallel_attn {
-            None
-        } else {
-            let ln = layer_norm(
-                cfg.hidden_size,
-                cfg.layer_norm_epsilon,
-                vb.pp("post_attention_layernorm"),
-            )?;
-            Some(ln)
-        };
-        Ok(Self {
-            inp_layernorm,
-            self_attention,
-            post_attention_layernorm,
-            mlp,
-            parallel_attn: cfg.parallel_attn,
-        })
-    }
-
-    fn forward(&mut self, x: &Tensor, mask: &Tensor, past_kv_len: usize) -> Result<Tensor> {
-        let residual = x.clone();
-        let ln_attn = self.inp_layernorm.forward(x)?;
-        let attn_output = self.self_attention.forward(&ln_attn, mask, past_kv_len)?;
-        let (residual, ln_mlp) = match &self.post_attention_layernorm {
-            None => (residual, ln_attn),
-            Some(pal) => {
-                // This should include some dropout.
-                let residual = (&attn_output + &residual)?;
-                let ln_mlp = pal.forward(&residual)?;
-                (residual, ln_mlp)
-            }
-        };
-        let mlp_output = self.mlp.forward(&ln_mlp)?;
-
-        let mlp_output = if self.parallel_attn {
-            (mlp_output + attn_output)?
-        } else {
-            mlp_output
-        };
-        let output = (mlp_output + residual)?;
-        Ok(output)
-    }
-}
-
-#[derive(Debug)]
-pub struct Falcon {
-    word_embeddings: Embedding,
-    blocks: Vec<FalconDecoderLayer>,
-    ln_f: LayerNorm,
-    lm_head: Linear,
-    config: Config,
-}
-
-fn make_causal_mask(t: usize) -> Result<Tensor> {
-    let mask: Vec<_> = (0..t)
-        .flat_map(|i| (0..t).map(move |j| u8::from(j > i)))
-        .collect();
-    let mask = Tensor::from_slice(&mask, (t, t), &Device::Cpu)?;
-    Ok(mask)
-}
-
-fn prepare_attn_mask(b_sz: usize, seq_len: usize) -> Result<Tensor> {
-    // let mask = Tensor::ones((b_sz, seq_len), DType::U32, &Device::Cpu)?;
-    let mask = make_causal_mask(seq_len)?;
-    let mask = mask.broadcast_as((b_sz, 1, seq_len, seq_len))?;
-    Ok(mask)
-}
-
-impl Falcon {
-    pub fn config(&self) -> &Config {
-        &self.config
-    }
-
-    pub fn load(vb: VarBuilder, cfg: Config) -> Result<Self> {
-        let word_embeddings = embedding(
-            cfg.vocab_size,
-            cfg.hidden_size,
-            vb.pp("transformer.word_embeddings"),
-        )?;
-        let blocks = (0..cfg.num_hidden_layers)
-            .map(|i| FalconDecoderLayer::load(vb.pp(&format!("transformer.h.{i}")), &cfg))
-            .collect::<Result<Vec<_>>>()?;
-        let ln_f = layer_norm(
-            cfg.hidden_size,
-            cfg.layer_norm_epsilon,
-            vb.pp("transformer.ln_f"),
-        )?;
-        let lm_head = linear(cfg.hidden_size, cfg.vocab_size, false, vb.pp("lm_head"))?;
-        Ok(Self {
-            word_embeddings,
-            blocks,
-            ln_f,
-            lm_head,
-            config: cfg,
-        })
-    }
-
-    pub fn forward(&mut self, input_ids: &Tensor) -> Result<Tensor> {
-        let (b_sz, seq_len) = input_ids.dims2()?;
-        let mut hidden_state = self.word_embeddings.forward(input_ids)?;
-        let past_kv_len = match &self.blocks[0].self_attention.kv_cache {
-            Some((k, _)) => k.dim(1)?,
-            None => 0,
-        };
-        let causal_mask = prepare_attn_mask(b_sz, seq_len)?.to_device(input_ids.device())?;
-        for block in self.blocks.iter_mut() {
-            hidden_state = block.forward(&hidden_state, &causal_mask, past_kv_len)?;
-        }
-        let hidden_state = self.ln_f.forward(&hidden_state)?;
-        let hidden_state = hidden_state.narrow(1, seq_len - 1, 1)?;
-        let logits = self.lm_head.forward(&hidden_state)?.squeeze(1)?;
-        Ok(logits)
-    }
-}
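
For context, a minimal driver for the module removed above could look roughly like the sketch below. It is not the example's actual main.rs: it assumes the deleted file is still compiled in as `mod model;`, substitutes a deliberately tiny hypothetical Config for `Config::falcon7b()` so that zero-initialized weights stay small, and builds those weights with `candle_nn::VarBuilder::zeros` instead of loading the published tiiuae/falcon-7b safetensors. The token ids are placeholders.

// Hypothetical sibling main.rs, shown only as a usage sketch of the deleted module.
mod model;

use anyhow::Result;
use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use model::{Config, Falcon};

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Tiny made-up dimensions so zero-initialized weights fit in memory; a real run
    // would use Config::falcon7b() and a VarBuilder backed by the downloaded weights.
    let cfg = Config {
        vocab_size: 1000,
        hidden_size: 64,
        num_hidden_layers: 2,
        num_attention_heads: 8,
        ..Config::default()
    };
    cfg.validate()?;
    let vb = VarBuilder::zeros(DType::F32, &device);
    let mut model = Falcon::load(vb, cfg)?;
    // A (batch = 1, seq_len = 3) prompt of placeholder token ids.
    let input_ids = Tensor::new(&[[11u32, 42, 7]], &device)?;
    // forward() keeps only the last position, so logits have shape (1, vocab_size).
    let logits = model.forward(&input_ids)?;
    println!("{:?}", logits.shape());
    Ok(())
}

With real weights loaded, repeated calls to forward() reuse the per-layer kv_cache (via past_kv_len), so each generation step only needs to feed the newly sampled token id.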