Diffstat (limited to 'candle-examples/examples/falcon/model.rs')
-rw-r--r--  candle-examples/examples/falcon/model.rs  485
1 file changed, 0 insertions, 485 deletions
diff --git a/candle-examples/examples/falcon/model.rs b/candle-examples/examples/falcon/model.rs
deleted file mode 100644
index b638dd51..00000000
--- a/candle-examples/examples/falcon/model.rs
+++ /dev/null
@@ -1,485 +0,0 @@
-use anyhow::Result;
-use candle::{DType, Device, Tensor, D};
-use candle_nn::{Embedding, LayerNorm, Linear, Module, VarBuilder};
-
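-// Maximum sequence length used to precompute the rotary cos/sin tables.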
-const MAX_SEQ_LEN: usize = 5000;
-
-fn linear(size1: usize, size2: usize, bias: bool, vb: VarBuilder) -> Result<Linear> {
- let weight = vb.get((size2, size1), "weight")?;
- let bias = if bias {
- Some(vb.get(size2, "bias")?)
- } else {
- None
- };
- Ok(Linear::new(weight, bias))
-}
-
-fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> {
- let (weight, bias) = match (vb.get(size, "weight"), vb.get(size, "bias")) {
- (Ok(weight), Ok(bias)) => (weight, bias),
- (Err(err), _) | (_, Err(err)) => {
- if let (Ok(weight), Ok(bias)) = (vb.get(size, "gamma"), vb.get(size, "beta")) {
- (weight, bias)
- } else {
- return Err(err.into());
- }
- }
- };
- Ok(LayerNorm::new(weight, bias, eps))
-}
-
-fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
- let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
- Ok(Embedding::new(embeddings, hidden_size))
-}
-
-// https://raw.githubusercontent.com/huggingface/transformers/030c863aaa0165e98352b61697430bf69bf33755/src/transformers/models/falcon/configuration_falcon.py
-#[derive(Debug)]
-pub struct Config {
- pub vocab_size: usize,
- pub hidden_size: usize,
- pub num_hidden_layers: usize,
- pub num_attention_heads: usize,
- pub layer_norm_epsilon: f64,
- pub initializer_range: f64,
- pub use_cache: bool,
- pub bos_token_id: u32,
- pub eos_token_id: u32,
- pub hidden_dropout: f64,
- pub attention_dropout: f64,
- pub n_head_kv: Option<usize>,
- pub alibi: bool,
- pub new_decoder_architecture: bool,
- pub multi_query: bool,
- pub parallel_attn: bool,
- pub bias: bool,
-}
-
-impl Default for Config {
- fn default() -> Self {
- Self {
- vocab_size: 65024,
- hidden_size: 4544,
- num_hidden_layers: 32,
- num_attention_heads: 71,
- layer_norm_epsilon: 1e-5,
- initializer_range: 0.02,
- use_cache: true,
- bos_token_id: 11,
- eos_token_id: 11,
- hidden_dropout: 0.0,
- attention_dropout: 0.0,
- n_head_kv: None,
- alibi: false,
- new_decoder_architecture: false,
- multi_query: true,
- parallel_attn: true,
- bias: false,
- }
- }
-}
-
-impl Config {
- pub fn validate(&self) -> Result<()> {
- if self.alibi {
- anyhow::bail!("alibi is not supported");
- }
- if self.new_decoder_architecture {
- anyhow::bail!("new_decoder_architecture is not supported");
- }
- if self.n_head_kv.is_some() {
- anyhow::bail!("n_head_kv is not supported");
- }
- Ok(())
- }
-
- // https://huggingface.co/tiiuae/falcon-7b/blob/main/config.json
- pub fn falcon7b() -> Self {
- // These values currently match the defaults: the defaults mirror the Python
- // config's default arguments, whereas the values below come from the json config.
- Self {
- vocab_size: 65024,
- hidden_size: 4544,
- num_hidden_layers: 32,
- num_attention_heads: 71,
- layer_norm_epsilon: 1e-5,
- initializer_range: 0.02,
- use_cache: true,
- bos_token_id: 11,
- eos_token_id: 11,
- hidden_dropout: 0.,
- attention_dropout: 0.,
- n_head_kv: None,
- alibi: false,
- new_decoder_architecture: false,
- multi_query: true,
- parallel_attn: true,
- bias: false,
- }
- }
-
- fn head_dim(&self) -> usize {
- self.hidden_size / self.num_attention_heads
- }
-
- fn rotary(&self) -> bool {
- !self.alibi
- }
-}
-
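-// Rotates the two halves of the last dimension: [x1, x2] -> [-x2, x1], the
-// rotation used by rotary position embeddings.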
-fn rotate_half(x: &Tensor) -> Result<Tensor> {
- let l = x.dim(D::Minus1)?;
- let x1 = x.narrow(D::Minus1, 0, l / 2)?;
- let x2 = x.narrow(D::Minus1, l / 2, l - l / 2)?;
- let x21 = Tensor::cat(&[&x2.neg()?, &x1], D::Minus1)?;
- Ok(x21)
-}
-
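-// Rotary position embedding with cached cos/sin tables, recomputed only when
-// the requested sequence length changes.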
-#[derive(Debug)]
-struct FalconRotaryEmbedding {
- inv_freq: Tensor,
- cache: Option<(usize, Tensor, Tensor)>,
-}
-
-impl FalconRotaryEmbedding {
- fn load(device: &Device, cfg: &Config) -> Result<Self> {
- let head_dim = cfg.head_dim();
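- // Inverse frequencies 1 / 10000^(i / head_dim) for even i, as in the RoPE paper.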
- let inv_freq: Vec<_> = (0..head_dim)
- .step_by(2)
- .map(|i| 1f32 / 10000f32.powf(i as f32 / head_dim as f32))
- .collect();
- Ok(Self {
- inv_freq: Tensor::new(inv_freq.as_slice(), device)?,
- cache: None,
- })
- }
-
- fn cos_sin(
- &mut self,
- seq_len: usize,
- device: &Device,
- dtype: DType,
- ) -> Result<(Tensor, Tensor)> {
- match &self.cache {
- Some((s, cos, sin)) if *s == seq_len => {
- return Ok((cos.clone(), sin.clone()));
- }
- _ => {}
- }
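- // The outer product of positions and inverse frequencies gives the rotation angles.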
- let t = Tensor::arange(0, seq_len as u32, device)?.to_dtype(dtype)?;
- let inv_freq = self.inv_freq.to_dtype(dtype)?;
- let freqs = t.unsqueeze(1)?.matmul(&inv_freq.unsqueeze(0)?)?;
- let emb = Tensor::cat(&[&freqs, &freqs], D::Minus1)?;
- let cos = emb.cos()?;
- let sin = emb.sin()?;
- self.cache = Some((seq_len, cos.clone(), sin.clone()));
- Ok((cos, sin))
- }
-
- fn forward(
- &mut self,
- query: &Tensor,
- key: &Tensor,
- past_kv_len: usize,
- ) -> Result<(Tensor, Tensor)> {
- let (_batch, seq_len, _head_dim) = query.dims3()?;
- let (cos, sin) = self.cos_sin(MAX_SEQ_LEN, query.device(), query.dtype())?;
- let cos = cos.narrow(0, past_kv_len, seq_len)?;
- let sin = sin.narrow(0, past_kv_len, seq_len)?;
- let qs = (query.broadcast_mul(&cos)? + &rotate_half(query)?.broadcast_mul(&sin)?)?;
- let ks = (key.broadcast_mul(&cos)? + &rotate_half(key)?.broadcast_mul(&sin)?)?;
- Ok((qs, ks))
- }
-}
-
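-// Returns on_false with on_true written wherever mask is non-zero.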
-fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
- let shape = mask.shape();
- let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
- let m = mask.where_cond(&on_true, on_false)?;
- Ok(m)
-}
-
-#[derive(Debug)]
-struct FalconAttention {
- query_key_value: Linear,
- dense: Linear,
- maybe_rotary: Option<FalconRotaryEmbedding>,
- kv_cache: Option<(Tensor, Tensor)>,
- inv_norm_factor: f64,
- multi_query: bool,
- use_cache: bool,
- num_heads: usize,
- head_dim: usize,
- n_head_kv: usize,
-}
-
-impl FalconAttention {
- fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
- let maybe_rotary = if cfg.rotary() {
- let rotary = FalconRotaryEmbedding::load(vb.device(), cfg)?;
- Some(rotary)
- } else {
- None
- };
- let head_dim = cfg.head_dim();
- let hidden_size = cfg.hidden_size;
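- // With multi-query attention, all query heads share a single key head and a
- // single value head, so the fused projection only adds 2 * head_dim outputs.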
- let qkv_out_dim = if cfg.multi_query {
- hidden_size + 2 * head_dim
- } else {
- 3 * hidden_size
- };
- let query_key_value = linear(hidden_size, qkv_out_dim, cfg.bias, vb.pp("query_key_value"))?;
- let dense = linear(hidden_size, hidden_size, cfg.bias, vb.pp("dense"))?;
- Ok(Self {
- query_key_value,
- dense,
- maybe_rotary,
- kv_cache: None,
- inv_norm_factor: 1. / (head_dim as f64).sqrt(),
- multi_query: cfg.multi_query,
- use_cache: cfg.use_cache,
- num_heads: cfg.num_attention_heads,
- n_head_kv: cfg.n_head_kv.unwrap_or(1),
- head_dim,
- })
- }
-
- fn split_heads(&self, fused_qkv: &Tensor) -> Result<(Tensor, Tensor, Tensor)> {
- let (b_sz, seq_len, _) = fused_qkv.dims3()?;
- if !self.multi_query {
- let fused_qkv = fused_qkv.reshape((b_sz, seq_len, self.num_heads, 3, self.head_dim))?;
- let q = fused_qkv.narrow(D::Minus2, 0, 1)?.squeeze(D::Minus2)?;
- let k = fused_qkv.narrow(D::Minus2, 1, 1)?.squeeze(D::Minus2)?;
- let v = fused_qkv.narrow(D::Minus2, 2, 1)?.squeeze(D::Minus2)?;
- Ok((q, k, v))
- } else {
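- // Multi-query layout: num_heads query heads followed by one key and one value head.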
- let fused_qkv =
- fused_qkv.reshape((b_sz, seq_len, self.num_heads + 2, self.head_dim))?;
- let d = fused_qkv.dim(D::Minus2)?;
- let q = fused_qkv.narrow(D::Minus2, 0, d - 2)?;
- let k = fused_qkv.narrow(D::Minus2, d - 2, 1)?;
- let v = fused_qkv.narrow(D::Minus2, d - 1, 1)?;
- Ok((q, k, v))
- }
- }
-
- fn forward(&mut self, x: &Tensor, mask: &Tensor, past_kv_len: usize) -> Result<Tensor> {
- let fused_qkv = self.query_key_value.forward(x)?;
- let head_dim = self.head_dim;
- let (query, key, value) = self.split_heads(&fused_qkv)?;
- let (b_sz, seq_len, _, _) = query.dims4()?;
- let query = query
- .transpose(1, 2)?
- .reshape((b_sz * self.num_heads, seq_len, head_dim))?;
- let key = key
- .transpose(1, 2)?
- .reshape((b_sz * self.n_head_kv, seq_len, head_dim))?;
- let value = value
- .transpose(1, 2)?
- .reshape((b_sz * self.n_head_kv, seq_len, head_dim))?;
- let (query, key) = if let Some(r) = &mut self.maybe_rotary {
- r.forward(&query, &key, past_kv_len)?
- } else {
- (query, key)
- };
- let (mut key, mut value) = (key, value);
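- // Turn the boolean mask into an additive bias: 0 where attention is allowed,
- // -1e9 where it is masked.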
- let mask = masked_fill(&mask.to_dtype(DType::F32)?, mask, -1e9)?.to_dtype(query.dtype())?;
- if self.use_cache {
- if let Some((cache_k, cache_v)) = &self.kv_cache {
- // TODO: we could trim the tensors to MAX_SEQ_LEN so that this would work for
- // arbitrarily large sizes.
- key = Tensor::cat(&[cache_k, &key], 1)?.contiguous()?;
- value = Tensor::cat(&[cache_v, &value], 1)?.contiguous()?;
- }
- self.kv_cache = Some((key.clone(), value.clone()))
- }
- let query = query.reshape((b_sz * self.num_heads, seq_len, head_dim))?;
- let all_len = past_kv_len + seq_len;
- let key = key.reshape((b_sz * self.n_head_kv, all_len, head_dim))?;
- let value = value.reshape((b_sz * self.n_head_kv, all_len, head_dim))?;
-
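- // A single shared key/value head is broadcast across all query heads.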
- let (key, value) = if self.n_head_kv == 1 {
- (
- key.broadcast_as((b_sz * self.num_heads, all_len, head_dim))?,
- value.broadcast_as((b_sz * self.num_heads, all_len, head_dim))?,
- )
- } else {
- (key, value)
- };
-
- // Only the non-alibi case is handled here, using standard (non-flash) attention.
- let attention_scores = (query.matmul(&key.t()?)? * self.inv_norm_factor)?;
- let attention_scores = candle_nn::ops::softmax(
- &attention_scores
- .broadcast_add(&mask.squeeze(1)?)?
- .to_dtype(DType::F32)?,
- D::Minus1,
- )?
- .to_dtype(x.dtype())?;
- let attn_output = attention_scores
- .matmul(&value)?
- .reshape((b_sz, self.num_heads, seq_len, head_dim))?
- .transpose(1, 2)?
- .reshape((b_sz, seq_len, self.num_heads * head_dim))?;
- let attn_output = self.dense.forward(&attn_output)?;
- Ok(attn_output)
- }
-}
-
-#[derive(Debug)]
-struct FalconMlp {
- dense_h_to_4h: Linear,
- dense_4h_to_h: Linear,
-}
-
-impl FalconMlp {
- fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
- let h = cfg.hidden_size;
- let b = cfg.bias;
- let dense_h_to_4h = linear(h, 4 * h, b, vb.pp("dense_h_to_4h"))?;
- let dense_4h_to_h = linear(4 * h, h, b, vb.pp("dense_4h_to_h"))?;
- Ok(Self {
- dense_h_to_4h,
- dense_4h_to_h,
- })
- }
-
- fn forward(&self, x: &Tensor) -> Result<Tensor> {
- let x = self.dense_h_to_4h.forward(x)?.gelu()?;
- let x = self.dense_4h_to_h.forward(&x)?;
- Ok(x)
- }
-}
-
-#[derive(Debug)]
-struct FalconDecoderLayer {
- inp_layernorm: LayerNorm,
- self_attention: FalconAttention,
- post_attention_layernorm: Option<LayerNorm>,
- mlp: FalconMlp,
- parallel_attn: bool,
-}
-
-impl FalconDecoderLayer {
- fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
- let mlp = FalconMlp::load(vb.pp("mlp"), cfg)?;
- let inp_layernorm = layer_norm(
- cfg.hidden_size,
- cfg.layer_norm_epsilon,
- vb.pp("input_layernorm"),
- )?;
- let self_attention = FalconAttention::load(vb.pp("self_attention"), cfg)?;
- let post_attention_layernorm = if cfg.parallel_attn {
- None
- } else {
- let ln = layer_norm(
- cfg.hidden_size,
- cfg.layer_norm_epsilon,
- vb.pp("post_attention_layernorm"),
- )?;
- Some(ln)
- };
- Ok(Self {
- inp_layernorm,
- self_attention,
- post_attention_layernorm,
- mlp,
- parallel_attn: cfg.parallel_attn,
- })
- }
-
- fn forward(&mut self, x: &Tensor, mask: &Tensor, past_kv_len: usize) -> Result<Tensor> {
- let residual = x.clone();
- let ln_attn = self.inp_layernorm.forward(x)?;
- let attn_output = self.self_attention.forward(&ln_attn, mask, past_kv_len)?;
- let (residual, ln_mlp) = match &self.post_attention_layernorm {
- None => (residual, ln_attn),
- Some(pal) => {
- // This should include some dropout.
- let residual = (&attn_output + &residual)?;
- let ln_mlp = pal.forward(&residual)?;
- (residual, ln_mlp)
- }
- };
- let mlp_output = self.mlp.forward(&ln_mlp)?;
-
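- // In the parallel formulation both branches read the same layer norm output:
- // output = mlp(ln(x)) + attn(ln(x)) + x.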
- let mlp_output = if self.parallel_attn {
- (mlp_output + attn_output)?
- } else {
- mlp_output
- };
- let output = (mlp_output + residual)?;
- Ok(output)
- }
-}
-
-#[derive(Debug)]
-pub struct Falcon {
- word_embeddings: Embedding,
- blocks: Vec<FalconDecoderLayer>,
- ln_f: LayerNorm,
- lm_head: Linear,
- config: Config,
-}
-
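-// Builds a (t, t) upper-triangular mask whose entry (i, j) is 1 when j > i,
-// i.e. for the future positions a token must not attend to.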
-fn make_causal_mask(t: usize) -> Result<Tensor> {
- let mask: Vec<_> = (0..t)
- .flat_map(|i| (0..t).map(move |j| u8::from(j > i)))
- .collect();
- let mask = Tensor::from_slice(&mask, (t, t), &Device::Cpu)?;
- Ok(mask)
-}
-
-fn prepare_attn_mask(b_sz: usize, seq_len: usize) -> Result<Tensor> {
- // let mask = Tensor::ones((b_sz, seq_len), DType::U32, &Device::Cpu)?;
- let mask = make_causal_mask(seq_len)?;
- let mask = mask.broadcast_as((b_sz, 1, seq_len, seq_len))?;
- Ok(mask)
-}
-
-impl Falcon {
- pub fn config(&self) -> &Config {
- &self.config
- }
-
- pub fn load(vb: VarBuilder, cfg: Config) -> Result<Self> {
- let word_embeddings = embedding(
- cfg.vocab_size,
- cfg.hidden_size,
- vb.pp("transformer.word_embeddings"),
- )?;
- let blocks = (0..cfg.num_hidden_layers)
- .map(|i| FalconDecoderLayer::load(vb.pp(&format!("transformer.h.{i}")), &cfg))
- .collect::<Result<Vec<_>>>()?;
- let ln_f = layer_norm(
- cfg.hidden_size,
- cfg.layer_norm_epsilon,
- vb.pp("transformer.ln_f"),
- )?;
- let lm_head = linear(cfg.hidden_size, cfg.vocab_size, false, vb.pp("lm_head"))?;
- Ok(Self {
- word_embeddings,
- blocks,
- ln_f,
- lm_head,
- config: cfg,
- })
- }
-
- pub fn forward(&mut self, input_ids: &Tensor) -> Result<Tensor> {
- let (b_sz, seq_len) = input_ids.dims2()?;
- let mut hidden_state = self.word_embeddings.forward(input_ids)?;
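- // All blocks update their caches in lockstep, so the first block's cache
- // length gives the number of previously processed tokens.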
- let past_kv_len = match &self.blocks[0].self_attention.kv_cache {
- Some((k, _)) => k.dim(1)?,
- None => 0,
- };
- let causal_mask = prepare_attn_mask(b_sz, seq_len)?.to_device(input_ids.device())?;
- for block in self.blocks.iter_mut() {
- hidden_state = block.forward(&hidden_state, &causal_mask, past_kv_len)?;
- }
- let hidden_state = self.ln_f.forward(&hidden_state)?;
- let hidden_state = hidden_state.narrow(1, seq_len - 1, 1)?;
- let logits = self.lm_head.forward(&hidden_state)?.squeeze(1)?;
- Ok(logits)
- }
-}