diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-08-17 19:31:23 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-08-17 20:31:23 +0200 |
commit | c1b9e07e3549574659b189389975c1152b0776f5 (patch) | |
tree | 25f30a3c0dc483a86f920e609c5dd1e52594855d /candle-transformers | |
parent | 69fdcfe96ac05213b3b166140774f38a99de0b54 (diff) | |
download | candle-c1b9e07e3549574659b189389975c1152b0776f5.tar.gz candle-c1b9e07e3549574659b189389975c1152b0776f5.tar.bz2 candle-c1b9e07e3549574659b189389975c1152b0776f5.zip |
Add support for gemma-2. (#2425)
* Add gemma-2.
* Support a couple more models.
* Sliding window support.
* Example + readme updates.
* Update the main readme.
Diffstat (limited to 'candle-transformers')
-rw-r--r-- | candle-transformers/src/models/gemma2.rs | 449 | ||||
-rw-r--r-- | candle-transformers/src/models/mod.rs | 1 |
2 files changed, 450 insertions, 0 deletions
diff --git a/candle-transformers/src/models/gemma2.rs b/candle-transformers/src/models/gemma2.rs new file mode 100644 index 00000000..f0d65047 --- /dev/null +++ b/candle-transformers/src/models/gemma2.rs @@ -0,0 +1,449 @@ +use std::sync::Arc; + +use candle::{DType, Device, Module, Result, Tensor, D}; +use candle_nn::{linear_b as linear, Activation, Linear, VarBuilder}; + +fn default_max_position_embeddings() -> usize { + 4096 +} + +#[derive(serde::Deserialize, Debug, Clone)] +pub struct Config { + pub attention_bias: bool, + pub head_dim: usize, + pub hidden_activation: Activation, + pub hidden_size: usize, + pub intermediate_size: usize, + pub num_attention_heads: usize, + pub num_hidden_layers: usize, + pub num_key_value_heads: usize, + pub rms_norm_eps: f64, + pub rope_theta: f64, + pub vocab_size: usize, + pub final_logit_softcapping: Option<f64>, + pub attn_logit_softcapping: Option<f64>, + pub query_pre_attn_scalar: usize, + // TODO: Handle the sliding window in the attention mask. + pub sliding_window: Option<usize>, + + #[serde(default = "default_max_position_embeddings")] + pub max_position_embeddings: usize, +} + +#[derive(Debug, Clone)] +struct RmsNorm { + weight: Tensor, + eps: f64, +} + +impl RmsNorm { + fn new(dim: usize, eps: f64, vb: VarBuilder) -> Result<Self> { + let weight = vb.get(dim, "weight")?; + Ok(Self { weight, eps }) + } +} + +impl Module for RmsNorm { + fn forward(&self, x: &Tensor) -> Result<Tensor> { + let x_dtype = x.dtype(); + let internal_dtype = match x_dtype { + DType::F16 | DType::BF16 => DType::F32, + d => d, + }; + let hidden_size = x.dim(D::Minus1)?; + let x = x.to_dtype(internal_dtype)?; + let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?; + let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?; + x_normed + .to_dtype(x_dtype)? + .broadcast_mul(&(&self.weight + 1.0)?) + } +} + +#[derive(Debug, Clone)] +struct RotaryEmbedding { + sin: Tensor, + cos: Tensor, +} + +impl RotaryEmbedding { + fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> { + let dim = cfg.head_dim; + let max_seq_len = cfg.max_position_embeddings; + let inv_freq: Vec<_> = (0..dim) + .step_by(2) + .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32) + .collect(); + let inv_freq_len = inv_freq.len(); + let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?; + let t = Tensor::arange(0u32, max_seq_len as u32, dev)? + .to_dtype(dtype)? + .reshape((max_seq_len, 1))?; + let freqs = t.matmul(&inv_freq)?; + Ok(Self { + sin: freqs.sin()?, + cos: freqs.cos()?, + }) + } + + fn apply_rotary_emb_qkv( + &self, + q: &Tensor, + k: &Tensor, + seqlen_offset: usize, + ) -> Result<(Tensor, Tensor)> { + let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?; + let cos = self.cos.narrow(0, seqlen_offset, seq_len)?; + let sin = self.sin.narrow(0, seqlen_offset, seq_len)?; + let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?; + let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?; + Ok((q_embed, k_embed)) + } +} + +#[derive(Debug, Clone)] +#[allow(clippy::upper_case_acronyms)] +struct MLP { + gate_proj: Linear, + up_proj: Linear, + down_proj: Linear, + act_fn: candle_nn::Activation, +} + +impl MLP { + fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> { + let hidden_sz = cfg.hidden_size; + let intermediate_sz = cfg.intermediate_size; + let gate_proj = linear(hidden_sz, intermediate_sz, false, vb.pp("gate_proj"))?; + let up_proj = linear(hidden_sz, intermediate_sz, false, vb.pp("up_proj"))?; + let down_proj = linear(intermediate_sz, hidden_sz, false, vb.pp("down_proj"))?; + Ok(Self { + gate_proj, + up_proj, + down_proj, + act_fn: cfg.hidden_activation, + }) + } +} + +impl Module for MLP { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let lhs = xs.apply(&self.gate_proj)?.apply(&self.act_fn)?; + let rhs = xs.apply(&self.up_proj)?; + (lhs * rhs)?.apply(&self.down_proj) + } +} + +#[derive(Debug, Clone)] +struct Attention { + q_proj: Linear, + k_proj: Linear, + v_proj: Linear, + o_proj: Linear, + num_heads: usize, + num_kv_heads: usize, + num_kv_groups: usize, + head_dim: usize, + attn_logit_softcapping: Option<f64>, + rotary_emb: Arc<RotaryEmbedding>, + kv_cache: Option<(Tensor, Tensor)>, + use_flash_attn: bool, +} + +impl Attention { + fn new( + rotary_emb: Arc<RotaryEmbedding>, + use_flash_attn: bool, + cfg: &Config, + vb: VarBuilder, + ) -> Result<Self> { + let hidden_sz = cfg.hidden_size; + let num_heads = cfg.num_attention_heads; + let num_kv_heads = cfg.num_key_value_heads; + let num_kv_groups = num_heads / num_kv_heads; + let head_dim = cfg.head_dim; + let bias = cfg.attention_bias; + let q_proj = linear(hidden_sz, num_heads * head_dim, bias, vb.pp("q_proj"))?; + let k_proj = linear(hidden_sz, num_kv_heads * head_dim, bias, vb.pp("k_proj"))?; + let v_proj = linear(hidden_sz, num_kv_heads * head_dim, bias, vb.pp("v_proj"))?; + let o_proj = linear(num_heads * head_dim, hidden_sz, bias, vb.pp("o_proj"))?; + Ok(Self { + q_proj, + k_proj, + v_proj, + o_proj, + num_heads, + num_kv_heads, + num_kv_groups, + head_dim, + attn_logit_softcapping: cfg.attn_logit_softcapping, + rotary_emb, + kv_cache: None, + use_flash_attn, + }) + } + + fn forward( + &mut self, + xs: &Tensor, + attention_mask: Option<&Tensor>, + seqlen_offset: usize, + ) -> Result<Tensor> { + let (b_sz, q_len, _) = xs.dims3()?; + + let query_states = self.q_proj.forward(xs)?; + let key_states = self.k_proj.forward(xs)?; + let value_states = self.v_proj.forward(xs)?; + + let query_states = query_states + .reshape((b_sz, q_len, self.num_heads, self.head_dim))? + .transpose(1, 2)?; + let key_states = key_states + .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))? + .transpose(1, 2)?; + let value_states = value_states + .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))? + .transpose(1, 2)?; + + let (query_states, key_states) = + self.rotary_emb + .apply_rotary_emb_qkv(&query_states, &key_states, seqlen_offset)?; + + let (key_states, value_states) = match &self.kv_cache { + None => (key_states, value_states), + Some((prev_k, prev_v)) => { + let key_states = Tensor::cat(&[prev_k, &key_states], 2)?; + let value_states = Tensor::cat(&[prev_v, &value_states], 2)?; + (key_states, value_states) + } + }; + self.kv_cache = Some((key_states.clone(), value_states.clone())); + + let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?.contiguous()?; + let value_states = + crate::utils::repeat_kv(value_states, self.num_kv_groups)?.contiguous()?; + + let attn_output = if self.use_flash_attn { + // flash-attn expects (b_sz, seq_len, nheads, head_dim) + let q = query_states.transpose(1, 2)?; + let k = key_states.transpose(1, 2)?; + let v = value_states.transpose(1, 2)?; + let scale = 1f32 / (self.head_dim as f32).sqrt(); + flash_attn(&q, &k, &v, scale, attention_mask.is_some())?.transpose(1, 2)? + } else { + let scale = 1f64 / f64::sqrt(self.head_dim as f64); + let attn_weights = (query_states.matmul(&key_states.transpose(2, 3)?)? * scale)?; + + let attn_weights = match self.attn_logit_softcapping { + None => attn_weights, + Some(sc) => ((attn_weights / sc)?.tanh()? * sc)?, + }; + + let attn_weights = match attention_mask { + None => attn_weights, + Some(mask) => attn_weights.broadcast_add(mask)?, + }; + let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?; + attn_weights.matmul(&value_states)? + }; + attn_output + .transpose(1, 2)? + .reshape((b_sz, q_len, ()))? + .apply(&self.o_proj) + } + + fn clear_kv_cache(&mut self) { + self.kv_cache = None + } +} + +#[cfg(feature = "flash-attn")] +fn flash_attn( + q: &Tensor, + k: &Tensor, + v: &Tensor, + softmax_scale: f32, + causal: bool, +) -> Result<Tensor> { + candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal) +} + +#[cfg(not(feature = "flash-attn"))] +fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor> { + unimplemented!("compile with '--features flash-attn'") +} + +#[derive(Debug, Clone)] +struct DecoderLayer { + self_attn: Attention, + mlp: MLP, + input_layernorm: RmsNorm, + pre_feedforward_layernorm: RmsNorm, + post_feedforward_layernorm: RmsNorm, + post_attention_layernorm: RmsNorm, +} + +impl DecoderLayer { + fn new( + rotary_emb: Arc<RotaryEmbedding>, + use_flash_attn: bool, + cfg: &Config, + vb: VarBuilder, + ) -> Result<Self> { + let self_attn = Attention::new(rotary_emb, use_flash_attn, cfg, vb.pp("self_attn"))?; + let mlp = MLP::new(cfg, vb.pp("mlp"))?; + let input_layernorm = + RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?; + let pre_feedforward_layernorm = RmsNorm::new( + cfg.hidden_size, + cfg.rms_norm_eps, + vb.pp("pre_feedforward_layernorm"), + )?; + let post_feedforward_layernorm = RmsNorm::new( + cfg.hidden_size, + cfg.rms_norm_eps, + vb.pp("post_feedforward_layernorm"), + )?; + let post_attention_layernorm = RmsNorm::new( + cfg.hidden_size, + cfg.rms_norm_eps, + vb.pp("post_attention_layernorm"), + )?; + Ok(Self { + self_attn, + mlp, + input_layernorm, + pre_feedforward_layernorm, + post_feedforward_layernorm, + post_attention_layernorm, + }) + } + + fn forward( + &mut self, + xs: &Tensor, + attention_mask: Option<&Tensor>, + seqlen_offset: usize, + ) -> Result<Tensor> { + let residual = xs; + let xs = self.input_layernorm.forward(xs)?; + let xs = self.self_attn.forward(&xs, attention_mask, seqlen_offset)?; + let xs = xs.apply(&self.post_attention_layernorm)?; + let xs = (xs + residual)?; + let residual = &xs; + let xs = xs.apply(&self.pre_feedforward_layernorm)?; + let xs = xs.apply(&self.mlp)?; + let xs = xs.apply(&self.post_feedforward_layernorm)?; + residual + xs + } + + fn clear_kv_cache(&mut self) { + self.self_attn.clear_kv_cache() + } +} + +#[derive(Debug, Clone)] +pub struct Model { + embed_tokens: candle_nn::Embedding, + layers: Vec<DecoderLayer>, + norm: RmsNorm, + lm_head: Linear, + final_logit_softcapping: Option<f64>, + device: Device, + dtype: DType, + hidden_size: usize, + sliding_window: Option<usize>, +} + +impl Model { + pub fn new(use_flash_attn: bool, cfg: &Config, vb: VarBuilder) -> Result<Self> { + let vb_m = vb.pp("model"); + let embed_tokens = + candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?; + let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb_m.device())?); + let mut layers = Vec::with_capacity(cfg.num_hidden_layers); + let vb_l = vb_m.pp("layers"); + for layer_idx in 0..cfg.num_hidden_layers { + let layer = + DecoderLayer::new(rotary_emb.clone(), use_flash_attn, cfg, vb_l.pp(layer_idx))?; + layers.push(layer) + } + let norm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb_m.pp("norm"))?; + let lm_head = Linear::new(embed_tokens.embeddings().clone(), None); + Ok(Self { + embed_tokens, + layers, + norm, + lm_head, + final_logit_softcapping: cfg.final_logit_softcapping, + device: vb.device().clone(), + dtype: vb.dtype(), + hidden_size: cfg.hidden_size, + sliding_window: cfg.sliding_window, + }) + } + + fn prepare_decoder_attention_mask( + &self, + b_size: usize, + tgt_len: usize, + seqlen_offset: usize, + ) -> Result<Tensor> { + let mask: Vec<_> = match self.sliding_window { + None => (0..tgt_len) + .flat_map(|i| (0..tgt_len).map(move |j| if i < j { f32::NEG_INFINITY } else { 0. })) + .collect(), + Some(sliding_window) => (0..tgt_len) + .flat_map(|i| { + (0..tgt_len).map(move |j| { + if i < j || j + sliding_window < i { + f32::NEG_INFINITY + } else { + 0. + } + }) + }) + .collect(), + }; + let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?; + let mask = if seqlen_offset > 0 { + let mask0 = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, &self.device)?; + Tensor::cat(&[&mask0, &mask], D::Minus1)? + } else { + mask + }; + mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))? + .to_dtype(self.dtype) + } + + pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> { + let (b_size, seq_len) = input_ids.dims2()?; + let attention_mask = if seq_len <= 1 { + None + } else { + let mask = self.prepare_decoder_attention_mask(b_size, seq_len, seqlen_offset)?; + Some(mask) + }; + let xs = self.embed_tokens.forward(input_ids)?; + let mut xs = (xs * (self.hidden_size as f64).sqrt())?; + for layer in self.layers.iter_mut() { + xs = layer.forward(&xs, attention_mask.as_ref(), seqlen_offset)? + } + let logits = xs + .narrow(1, seq_len - 1, 1)? + .apply(&self.norm)? + .apply(&self.lm_head)?; + let logits = match self.final_logit_softcapping { + None => logits, + Some(sc) => ((logits / sc)?.tanh()? * sc)?, + }; + + Ok(logits) + } + + pub fn clear_kv_cache(&mut self) { + for layer in self.layers.iter_mut() { + layer.clear_kv_cache() + } + } +} diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs index 7baaaf72..16952c6a 100644 --- a/candle-transformers/src/models/mod.rs +++ b/candle-transformers/src/models/mod.rs @@ -20,6 +20,7 @@ pub mod eva2; pub mod falcon; pub mod flux; pub mod gemma; +pub mod gemma2; pub mod glm4; pub mod hiera; pub mod jina_bert; |