diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-12-15 14:19:56 -0600 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-12-15 14:19:56 -0600 |
commit | 614842b311a12ac5aba130e165763f997d8ff324 (patch) | |
tree | 6f10f4f5dd3b966ff20d1876914d00d8fd2473c9 /candle-transformers | |
parent | 79eab519fdb4a84cf080568ec9faaa2e43953bf6 (diff) | |
download | candle-614842b311a12ac5aba130e165763f997d8ff324.tar.gz candle-614842b311a12ac5aba130e165763f997d8ff324.tar.bz2 candle-614842b311a12ac5aba130e165763f997d8ff324.zip |
Add the Mixtral model. (#1437)
* Add the Mixtral model.
* Add more of the mixtral layers.
* Add the final layers for mixtral.
* Sketch the expert selection.
* Add some expert routing logic.
* Hopefully finish the routing logic for mixtral.
* Add the mixtral example.
* Fix the weight filenames.
* Bugfix.
* Another fix.
* Yet another fix + remove the unused pragma.
* Shape fix.
* Add a readme.
Diffstat (limited to 'candle-transformers')
-rw-r--r-- | candle-transformers/src/models/mixtral.rs | 499 | ||||
-rw-r--r-- | candle-transformers/src/models/mod.rs | 1 |
2 files changed, 500 insertions, 0 deletions
diff --git a/candle-transformers/src/models/mixtral.rs b/candle-transformers/src/models/mixtral.rs new file mode 100644 index 00000000..ede74d3f --- /dev/null +++ b/candle-transformers/src/models/mixtral.rs @@ -0,0 +1,499 @@ +use crate::models::with_tracing::{linear_no_bias, Linear}; +/// Mixtral Model +/// https://github.com/huggingface/transformers/blob/main/src/transformers/models/mixtral/modeling_mixtral.py +/// https://mistral.ai/news/mixtral-of-experts/ +use candle::{DType, Device, Module, Result, Tensor, D}; +use candle_nn::{Activation, VarBuilder}; +use serde::Deserialize; +use std::sync::Arc; + +/// https://github.com/huggingface/transformers/blob/1a585c1222a56bcaecc070966d558d4a9d862e83/src/transformers/models/mixtral/configuration_mixtral.py#L113 +#[derive(Debug, Clone, PartialEq, Deserialize)] +pub struct Config { + pub(crate) vocab_size: usize, + pub(crate) hidden_size: usize, + pub(crate) intermediate_size: usize, + pub(crate) num_hidden_layers: usize, + pub(crate) num_attention_heads: usize, + pub(crate) num_key_value_heads: usize, + pub(crate) hidden_act: Activation, + pub(crate) max_position_embeddings: usize, + pub(crate) rms_norm_eps: f64, + pub(crate) rope_theta: f64, + pub(crate) sliding_window: usize, + pub(crate) num_experts_per_tok: usize, + pub(crate) num_local_experts: usize, + pub(crate) use_flash_attn: bool, +} + +impl Config { + /// https://huggingface.co/mistralai/Mixtral-8x7B-v0.1/blob/main/config.json + pub fn v0_1_8x7b(use_flash_attn: bool) -> Self { + Self { + vocab_size: 32000, + hidden_size: 4096, + intermediate_size: 14336, + num_hidden_layers: 32, + num_attention_heads: 32, + num_key_value_heads: 8, + hidden_act: Activation::Silu, + max_position_embeddings: 32768, + rms_norm_eps: 1e-5, + rope_theta: 1e6, + sliding_window: 4096, + num_experts_per_tok: 2, + num_local_experts: 8, + use_flash_attn, + } + } +} + +#[derive(Debug, Clone)] +struct RmsNorm { + inner: candle_nn::RmsNorm, + span: tracing::Span, +} + +impl RmsNorm { + fn new(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> { + let span = tracing::span!(tracing::Level::TRACE, "rms-norm"); + let inner = candle_nn::rms_norm(size, eps, vb)?; + Ok(Self { inner, span }) + } +} + +impl Module for RmsNorm { + fn forward(&self, x: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + self.inner.forward(x) + } +} + +#[derive(Debug, Clone)] +struct RotaryEmbedding { + sin: Tensor, + cos: Tensor, +} + +fn rotate_half(xs: &Tensor) -> Result<Tensor> { + let last_dim = xs.dim(D::Minus1)?; + let xs1 = xs.narrow(D::Minus1, 0, last_dim / 2)?; + let xs2 = xs.narrow(D::Minus1, last_dim / 2, last_dim - last_dim / 2)?; + Tensor::cat(&[&xs2.neg()?, &xs1], D::Minus1) +} + +impl RotaryEmbedding { + fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> { + let dim = cfg.hidden_size / cfg.num_attention_heads; + let max_seq_len = cfg.max_position_embeddings; + let inv_freq: Vec<_> = (0..dim) + .step_by(2) + .map(|i| 1f32 / (cfg.rope_theta as f32).powf(i as f32 / dim as f32)) + .collect(); + let inv_freq_len = inv_freq.len(); + let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?; + let t = Tensor::arange(0u32, max_seq_len as u32, dev)? + .to_dtype(dtype)? + .reshape((max_seq_len, 1))?; + let freqs = t.matmul(&inv_freq)?; + let freqs = Tensor::cat(&[&freqs, &freqs], D::Minus1)?; + Ok(Self { + sin: freqs.sin()?, + cos: freqs.cos()?, + }) + } + + fn apply_rotary_emb_qkv( + &self, + q: &Tensor, + k: &Tensor, + seqlen_offset: usize, + ) -> Result<(Tensor, Tensor)> { + let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?; + let cos = self.cos.narrow(0, seqlen_offset, seq_len)?; + let sin = self.sin.narrow(0, seqlen_offset, seq_len)?; + let cos = cos.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim) + let sin = sin.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim) + let q_embed = (q.broadcast_mul(&cos)? + rotate_half(q)?.broadcast_mul(&sin))?; + let k_embed = (k.broadcast_mul(&cos)? + rotate_half(k)?.broadcast_mul(&sin))?; + Ok((q_embed, k_embed)) + } +} + +#[cfg(feature = "flash-attn")] +fn flash_attn( + q: &Tensor, + k: &Tensor, + v: &Tensor, + softmax_scale: f32, + causal: bool, +) -> Result<Tensor> { + candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal) +} + +#[cfg(not(feature = "flash-attn"))] +fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor> { + unimplemented!("compile with '--features flash-attn'") +} + +#[derive(Debug, Clone)] +struct Attention { + q_proj: Linear, + k_proj: Linear, + v_proj: Linear, + o_proj: Linear, + num_heads: usize, + num_kv_heads: usize, + num_kv_groups: usize, + head_dim: usize, + hidden_size: usize, + rotary_emb: Arc<RotaryEmbedding>, + kv_cache: Option<(Tensor, Tensor)>, + use_flash_attn: bool, +} + +impl Attention { + fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> { + let hidden_sz = cfg.hidden_size; + let num_heads = cfg.num_attention_heads; + let num_kv_heads = cfg.num_key_value_heads; + let num_kv_groups = num_heads / num_kv_heads; + let head_dim = hidden_sz / num_heads; + let q_proj = linear_no_bias(hidden_sz, num_heads * head_dim, vb.pp("q_proj"))?; + let k_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("k_proj"))?; + let v_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("v_proj"))?; + let o_proj = linear_no_bias(num_heads * head_dim, hidden_sz, vb.pp("o_proj"))?; + Ok(Self { + q_proj, + k_proj, + v_proj, + o_proj, + num_heads, + num_kv_heads, + num_kv_groups, + head_dim, + hidden_size: hidden_sz, + rotary_emb, + kv_cache: None, + use_flash_attn: cfg.use_flash_attn, + }) + } + + fn repeat_kv(&self, xs: Tensor) -> Result<Tensor> { + let n_rep = self.num_kv_groups; + if n_rep == 1 { + Ok(xs) + } else { + let (b_sz, num_kv_heads, seq_len, head_dim) = xs.dims4()?; + xs.unsqueeze(2)? + .expand((b_sz, num_kv_heads, n_rep, seq_len, head_dim))? + .reshape((b_sz, num_kv_heads * n_rep, seq_len, head_dim)) + } + } + + fn forward( + &mut self, + xs: &Tensor, + attention_mask: Option<&Tensor>, + seqlen_offset: usize, + ) -> Result<Tensor> { + let (b_sz, q_len, _) = xs.dims3()?; + + let query_states = self.q_proj.forward(xs)?; + let key_states = self.k_proj.forward(xs)?; + let value_states = self.v_proj.forward(xs)?; + + let query_states = query_states + .reshape((b_sz, q_len, self.num_heads, self.head_dim))? + .transpose(1, 2)?; + let key_states = key_states + .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))? + .transpose(1, 2)?; + let value_states = value_states + .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))? + .transpose(1, 2)?; + + let (query_states, key_states) = + self.rotary_emb + .apply_rotary_emb_qkv(&query_states, &key_states, seqlen_offset)?; + + let (key_states, value_states) = match &self.kv_cache { + None => (key_states, value_states), + Some((prev_k, prev_v)) => { + let key_states = Tensor::cat(&[prev_k, &key_states], 2)?; + let value_states = Tensor::cat(&[prev_v, &value_states], 2)?; + (key_states, value_states) + } + }; + self.kv_cache = Some((key_states.clone(), value_states.clone())); + + let key_states = self.repeat_kv(key_states)?; + let value_states = self.repeat_kv(value_states)?; + + let attn_output = if self.use_flash_attn { + // flash-attn expects (b_sz, seq_len, nheads, head_dim) + let q = query_states.transpose(1, 2)?; + let k = key_states.transpose(1, 2)?; + let v = value_states.transpose(1, 2)?; + let softmax_scale = 1f32 / (self.head_dim as f32).sqrt(); + flash_attn(&q, &k, &v, softmax_scale, q_len > 1)?.transpose(1, 2)? + } else { + let scale = 1f64 / f64::sqrt(self.head_dim as f64); + let attn_weights = (query_states.matmul(&key_states.transpose(2, 3)?)? * scale)?; + + let attn_weights = match attention_mask { + None => attn_weights, + Some(mask) => attn_weights.broadcast_add(mask)?, + }; + let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?; + attn_weights.matmul(&value_states)? + }; + attn_output + .transpose(1, 2)? + .reshape((b_sz, q_len, self.hidden_size))? + .apply(&self.o_proj) + } +} + +#[derive(Debug, Clone)] +struct BlockSparseTop2MLP { + w1: Linear, + w2: Linear, + w3: Linear, + act_fn: Activation, +} + +impl BlockSparseTop2MLP { + fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> { + let hidden_sz = cfg.hidden_size; + let intermediate_sz = cfg.intermediate_size; + let w1 = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("w1"))?; + let w2 = linear_no_bias(intermediate_sz, hidden_sz, vb.pp("w2"))?; + let w3 = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("w3"))?; + Ok(Self { + w1, + w2, + w3, + act_fn: cfg.hidden_act, + }) + } +} + +impl Module for BlockSparseTop2MLP { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let lhs = xs.apply(&self.w1)?.apply(&self.act_fn)?; + let rhs = xs.apply(&self.w3)?; + (lhs * rhs)?.apply(&self.w2) + } +} + +#[derive(Debug, Clone)] +struct SparseMoeBlock { + gate: Linear, + experts: Vec<BlockSparseTop2MLP>, + num_experts_per_tok: usize, +} + +impl SparseMoeBlock { + fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> { + let gate = linear_no_bias(cfg.hidden_size, cfg.num_local_experts, vb.pp("gate"))?; + let mut experts = Vec::with_capacity(cfg.num_local_experts); + let vb = vb.pp("experts"); + for idx in 0..cfg.num_local_experts { + let expert = BlockSparseTop2MLP::new(cfg, vb.pp(idx))?; + experts.push(expert) + } + Ok(SparseMoeBlock { + gate, + experts, + num_experts_per_tok: cfg.num_experts_per_tok, + }) + } +} + +impl Module for SparseMoeBlock { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let (b_size, seq_len, hidden_dim) = xs.dims3()?; + let xs = xs.reshape(((), hidden_dim))?; + let router_logits = xs.apply(&self.gate)?; + let routing_weights = candle_nn::ops::softmax_last_dim(&router_logits)?; + + // In order to extract topk, we extract the data from the tensor and manipulate it + // directly. Maybe we will want to use some custom ops instead at some point. + let routing_weights = routing_weights.to_dtype(DType::F32)?.to_vec2::<f32>()?; + + // routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) + // top_x contains the row indexes to evaluate for each expert. + let mut top_x = vec![vec![]; self.experts.len()]; + let mut selected_rws = vec![vec![]; self.experts.len()]; + for (row_idx, rw) in routing_weights.iter().enumerate() { + let mut dst = (0..rw.len() as u32).collect::<Vec<u32>>(); + dst.sort_by(|&i, &j| rw[j as usize].total_cmp(&rw[i as usize])); + let mut sum_routing_weights = 0f32; + for &expert_idx in dst.iter().take(self.num_experts_per_tok) { + let expert_idx = expert_idx as usize; + let routing_weight = rw[expert_idx]; + sum_routing_weights += routing_weight; + top_x[expert_idx].push(row_idx as u32); + } + for &expert_idx in dst.iter().take(self.num_experts_per_tok) { + let expert_idx = expert_idx as usize; + let routing_weight = rw[expert_idx]; + selected_rws[expert_idx].push(routing_weight / sum_routing_weights) + } + } + + // routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + // expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) + + let mut ys = xs.zeros_like()?; + for (expert_idx, expert_layer) in self.experts.iter().enumerate() { + let top_x = &top_x[expert_idx]; + if top_x.is_empty() { + continue; + } + let top_x = Tensor::new(top_x.as_slice(), xs.device())?; + let selected_rws = + Tensor::new(selected_rws[expert_idx].as_slice(), xs.device())?.reshape(((), 1))?; + // Index the correct hidden states and compute the expert hidden state for + // the current expert. We need to make sure to multiply the output hidden + // states by `routing_weights` on the corresponding tokens (top-1 and top-2) + let current_state = xs.index_select(&top_x, 0)?.reshape(((), hidden_dim))?; + // current_hidden_states = expert_layer(current_state, routing_weights[top_x_list, idx_list, None]) + let current_hidden_states = expert_layer.forward(¤t_state)?; + let current_hidden_states = current_hidden_states.broadcast_mul(&selected_rws)?; + ys = ys.index_add(&top_x, ¤t_hidden_states, 0)?; + } + + let ys = ys.reshape((b_size, seq_len, hidden_dim))?; + Ok(ys) + } +} + +#[derive(Debug, Clone)] +struct DecoderLayer { + self_attn: Attention, + block_sparse_moe: SparseMoeBlock, + input_layernorm: RmsNorm, + post_attention_layernorm: RmsNorm, +} + +impl DecoderLayer { + fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> { + let self_attn = Attention::new(rotary_emb, cfg, vb.pp("self_attn"))?; + let block_sparse_moe = SparseMoeBlock::new(cfg, vb.pp("block_sparse_moe"))?; + let input_layernorm = + RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?; + let post_attention_layernorm = RmsNorm::new( + cfg.hidden_size, + cfg.rms_norm_eps, + vb.pp("post_attention_layernorm"), + )?; + Ok(Self { + self_attn, + block_sparse_moe, + input_layernorm, + post_attention_layernorm, + }) + } + + fn forward( + &mut self, + xs: &Tensor, + attention_mask: Option<&Tensor>, + seqlen_offset: usize, + ) -> Result<Tensor> { + let residual = xs; + let xs = self.input_layernorm.forward(xs)?; + let xs = self.self_attn.forward(&xs, attention_mask, seqlen_offset)?; + let xs = (xs + residual)?; + let residual = &xs; + let xs = xs + .apply(&self.post_attention_layernorm)? + .apply(&self.block_sparse_moe)?; + residual + xs + } +} + +#[derive(Debug, Clone)] +pub struct Model { + embed_tokens: candle_nn::Embedding, + layers: Vec<DecoderLayer>, + norm: RmsNorm, + lm_head: Linear, + sliding_window: usize, + device: Device, + dtype: DType, +} + +impl Model { + pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> { + let vb_m = vb.pp("model"); + let embed_tokens = + candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?; + let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb_m.device())?); + let mut layers = Vec::with_capacity(cfg.num_hidden_layers); + let vb_l = vb_m.pp("layers"); + for layer_idx in 0..cfg.num_hidden_layers { + let layer = DecoderLayer::new(rotary_emb.clone(), cfg, vb_l.pp(layer_idx))?; + layers.push(layer) + } + let norm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb_m.pp("norm"))?; + let lm_head = linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?; + Ok(Self { + embed_tokens, + layers, + norm, + lm_head, + sliding_window: cfg.sliding_window, + device: vb.device().clone(), + dtype: vb.dtype(), + }) + } + + fn prepare_decoder_attention_mask( + &self, + b_size: usize, + tgt_len: usize, + seqlen_offset: usize, + ) -> Result<Tensor> { + // Sliding window mask? + let mask: Vec<_> = (0..tgt_len) + .flat_map(|i| { + (0..tgt_len).map(move |j| { + if i < j || j + self.sliding_window < i { + f32::NEG_INFINITY + } else { + 0. + } + }) + }) + .collect(); + let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?; + let mask = if seqlen_offset > 0 { + let mask0 = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, &self.device)?; + Tensor::cat(&[&mask0, &mask], D::Minus1)? + } else { + mask + }; + mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))? + .to_dtype(self.dtype) + } + + pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> { + let (b_size, seq_len) = input_ids.dims2()?; + let attention_mask = if seq_len <= 1 { + None + } else { + let mask = self.prepare_decoder_attention_mask(b_size, seq_len, seqlen_offset)?; + Some(mask) + }; + let mut xs = self.embed_tokens.forward(input_ids)?; + for layer in self.layers.iter_mut() { + xs = layer.forward(&xs, attention_mask.as_ref(), seqlen_offset)? + } + xs.narrow(1, seq_len - 1, 1)? + .apply(&self.norm)? + .apply(&self.lm_head) + } +} diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs index a9a56673..94a3bd5b 100644 --- a/candle-transformers/src/models/mod.rs +++ b/candle-transformers/src/models/mod.rs @@ -14,6 +14,7 @@ pub mod llama2_c_weights; pub mod marian; pub mod mistral; pub mod mixformer; +pub mod mixtral; pub mod mpt; pub mod persimmon; pub mod quantized_blip; |