summaryrefslogtreecommitdiff
path: root/candle-transformers
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-12-15 14:19:56 -0600
committerGitHub <noreply@github.com>2023-12-15 14:19:56 -0600
commit614842b311a12ac5aba130e165763f997d8ff324 (patch)
tree6f10f4f5dd3b966ff20d1876914d00d8fd2473c9 /candle-transformers
parent79eab519fdb4a84cf080568ec9faaa2e43953bf6 (diff)
downloadcandle-614842b311a12ac5aba130e165763f997d8ff324.tar.gz
candle-614842b311a12ac5aba130e165763f997d8ff324.tar.bz2
candle-614842b311a12ac5aba130e165763f997d8ff324.zip
Add the Mixtral model. (#1437)
* Add the Mixtral model. * Add more of the mixtral layers. * Add the final layers for mixtral. * Sketch the expert selection. * Add some expert routing logic. * Hopefully finish the routing logic for mixtral. * Add the mixtral example. * Fix the weight filenames. * Bugfix. * Another fix. * Yet another fix + remove the unused pragma. * Shape fix. * Add a readme.
Diffstat (limited to 'candle-transformers')
-rw-r--r--candle-transformers/src/models/mixtral.rs499
-rw-r--r--candle-transformers/src/models/mod.rs1
2 files changed, 500 insertions, 0 deletions
diff --git a/candle-transformers/src/models/mixtral.rs b/candle-transformers/src/models/mixtral.rs
new file mode 100644
index 00000000..ede74d3f
--- /dev/null
+++ b/candle-transformers/src/models/mixtral.rs
@@ -0,0 +1,499 @@
+use crate::models::with_tracing::{linear_no_bias, Linear};
+/// Mixtral Model
+/// https://github.com/huggingface/transformers/blob/main/src/transformers/models/mixtral/modeling_mixtral.py
+/// https://mistral.ai/news/mixtral-of-experts/
+use candle::{DType, Device, Module, Result, Tensor, D};
+use candle_nn::{Activation, VarBuilder};
+use serde::Deserialize;
+use std::sync::Arc;
+
+/// https://github.com/huggingface/transformers/blob/1a585c1222a56bcaecc070966d558d4a9d862e83/src/transformers/models/mixtral/configuration_mixtral.py#L113
+#[derive(Debug, Clone, PartialEq, Deserialize)]
+pub struct Config {
+ pub(crate) vocab_size: usize,
+ pub(crate) hidden_size: usize,
+ pub(crate) intermediate_size: usize,
+ pub(crate) num_hidden_layers: usize,
+ pub(crate) num_attention_heads: usize,
+ pub(crate) num_key_value_heads: usize,
+ pub(crate) hidden_act: Activation,
+ pub(crate) max_position_embeddings: usize,
+ pub(crate) rms_norm_eps: f64,
+ pub(crate) rope_theta: f64,
+ pub(crate) sliding_window: usize,
+ pub(crate) num_experts_per_tok: usize,
+ pub(crate) num_local_experts: usize,
+ pub(crate) use_flash_attn: bool,
+}
+
+impl Config {
+ /// https://huggingface.co/mistralai/Mixtral-8x7B-v0.1/blob/main/config.json
+ pub fn v0_1_8x7b(use_flash_attn: bool) -> Self {
+ Self {
+ vocab_size: 32000,
+ hidden_size: 4096,
+ intermediate_size: 14336,
+ num_hidden_layers: 32,
+ num_attention_heads: 32,
+ num_key_value_heads: 8,
+ hidden_act: Activation::Silu,
+ max_position_embeddings: 32768,
+ rms_norm_eps: 1e-5,
+ rope_theta: 1e6,
+ sliding_window: 4096,
+ num_experts_per_tok: 2,
+ num_local_experts: 8,
+ use_flash_attn,
+ }
+ }
+}
+
+#[derive(Debug, Clone)]
+struct RmsNorm {
+ inner: candle_nn::RmsNorm,
+ span: tracing::Span,
+}
+
+impl RmsNorm {
+ fn new(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
+ let span = tracing::span!(tracing::Level::TRACE, "rms-norm");
+ let inner = candle_nn::rms_norm(size, eps, vb)?;
+ Ok(Self { inner, span })
+ }
+}
+
+impl Module for RmsNorm {
+ fn forward(&self, x: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ self.inner.forward(x)
+ }
+}
+
+#[derive(Debug, Clone)]
+struct RotaryEmbedding {
+ sin: Tensor,
+ cos: Tensor,
+}
+
+fn rotate_half(xs: &Tensor) -> Result<Tensor> {
+ let last_dim = xs.dim(D::Minus1)?;
+ let xs1 = xs.narrow(D::Minus1, 0, last_dim / 2)?;
+ let xs2 = xs.narrow(D::Minus1, last_dim / 2, last_dim - last_dim / 2)?;
+ Tensor::cat(&[&xs2.neg()?, &xs1], D::Minus1)
+}
+
+impl RotaryEmbedding {
+ fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
+ let dim = cfg.hidden_size / cfg.num_attention_heads;
+ let max_seq_len = cfg.max_position_embeddings;
+ let inv_freq: Vec<_> = (0..dim)
+ .step_by(2)
+ .map(|i| 1f32 / (cfg.rope_theta as f32).powf(i as f32 / dim as f32))
+ .collect();
+ let inv_freq_len = inv_freq.len();
+ let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
+ let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
+ .to_dtype(dtype)?
+ .reshape((max_seq_len, 1))?;
+ let freqs = t.matmul(&inv_freq)?;
+ let freqs = Tensor::cat(&[&freqs, &freqs], D::Minus1)?;
+ Ok(Self {
+ sin: freqs.sin()?,
+ cos: freqs.cos()?,
+ })
+ }
+
+ fn apply_rotary_emb_qkv(
+ &self,
+ q: &Tensor,
+ k: &Tensor,
+ seqlen_offset: usize,
+ ) -> Result<(Tensor, Tensor)> {
+ let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
+ let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
+ let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
+ let cos = cos.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim)
+ let sin = sin.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim)
+ let q_embed = (q.broadcast_mul(&cos)? + rotate_half(q)?.broadcast_mul(&sin))?;
+ let k_embed = (k.broadcast_mul(&cos)? + rotate_half(k)?.broadcast_mul(&sin))?;
+ Ok((q_embed, k_embed))
+ }
+}
+
+#[cfg(feature = "flash-attn")]
+fn flash_attn(
+ q: &Tensor,
+ k: &Tensor,
+ v: &Tensor,
+ softmax_scale: f32,
+ causal: bool,
+) -> Result<Tensor> {
+ candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal)
+}
+
+#[cfg(not(feature = "flash-attn"))]
+fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor> {
+ unimplemented!("compile with '--features flash-attn'")
+}
+
+#[derive(Debug, Clone)]
+struct Attention {
+ q_proj: Linear,
+ k_proj: Linear,
+ v_proj: Linear,
+ o_proj: Linear,
+ num_heads: usize,
+ num_kv_heads: usize,
+ num_kv_groups: usize,
+ head_dim: usize,
+ hidden_size: usize,
+ rotary_emb: Arc<RotaryEmbedding>,
+ kv_cache: Option<(Tensor, Tensor)>,
+ use_flash_attn: bool,
+}
+
+impl Attention {
+ fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let hidden_sz = cfg.hidden_size;
+ let num_heads = cfg.num_attention_heads;
+ let num_kv_heads = cfg.num_key_value_heads;
+ let num_kv_groups = num_heads / num_kv_heads;
+ let head_dim = hidden_sz / num_heads;
+ let q_proj = linear_no_bias(hidden_sz, num_heads * head_dim, vb.pp("q_proj"))?;
+ let k_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("k_proj"))?;
+ let v_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("v_proj"))?;
+ let o_proj = linear_no_bias(num_heads * head_dim, hidden_sz, vb.pp("o_proj"))?;
+ Ok(Self {
+ q_proj,
+ k_proj,
+ v_proj,
+ o_proj,
+ num_heads,
+ num_kv_heads,
+ num_kv_groups,
+ head_dim,
+ hidden_size: hidden_sz,
+ rotary_emb,
+ kv_cache: None,
+ use_flash_attn: cfg.use_flash_attn,
+ })
+ }
+
+ fn repeat_kv(&self, xs: Tensor) -> Result<Tensor> {
+ let n_rep = self.num_kv_groups;
+ if n_rep == 1 {
+ Ok(xs)
+ } else {
+ let (b_sz, num_kv_heads, seq_len, head_dim) = xs.dims4()?;
+ xs.unsqueeze(2)?
+ .expand((b_sz, num_kv_heads, n_rep, seq_len, head_dim))?
+ .reshape((b_sz, num_kv_heads * n_rep, seq_len, head_dim))
+ }
+ }
+
+ fn forward(
+ &mut self,
+ xs: &Tensor,
+ attention_mask: Option<&Tensor>,
+ seqlen_offset: usize,
+ ) -> Result<Tensor> {
+ let (b_sz, q_len, _) = xs.dims3()?;
+
+ let query_states = self.q_proj.forward(xs)?;
+ let key_states = self.k_proj.forward(xs)?;
+ let value_states = self.v_proj.forward(xs)?;
+
+ let query_states = query_states
+ .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
+ .transpose(1, 2)?;
+ let key_states = key_states
+ .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
+ .transpose(1, 2)?;
+ let value_states = value_states
+ .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
+ .transpose(1, 2)?;
+
+ let (query_states, key_states) =
+ self.rotary_emb
+ .apply_rotary_emb_qkv(&query_states, &key_states, seqlen_offset)?;
+
+ let (key_states, value_states) = match &self.kv_cache {
+ None => (key_states, value_states),
+ Some((prev_k, prev_v)) => {
+ let key_states = Tensor::cat(&[prev_k, &key_states], 2)?;
+ let value_states = Tensor::cat(&[prev_v, &value_states], 2)?;
+ (key_states, value_states)
+ }
+ };
+ self.kv_cache = Some((key_states.clone(), value_states.clone()));
+
+ let key_states = self.repeat_kv(key_states)?;
+ let value_states = self.repeat_kv(value_states)?;
+
+ let attn_output = if self.use_flash_attn {
+ // flash-attn expects (b_sz, seq_len, nheads, head_dim)
+ let q = query_states.transpose(1, 2)?;
+ let k = key_states.transpose(1, 2)?;
+ let v = value_states.transpose(1, 2)?;
+ let softmax_scale = 1f32 / (self.head_dim as f32).sqrt();
+ flash_attn(&q, &k, &v, softmax_scale, q_len > 1)?.transpose(1, 2)?
+ } else {
+ let scale = 1f64 / f64::sqrt(self.head_dim as f64);
+ let attn_weights = (query_states.matmul(&key_states.transpose(2, 3)?)? * scale)?;
+
+ let attn_weights = match attention_mask {
+ None => attn_weights,
+ Some(mask) => attn_weights.broadcast_add(mask)?,
+ };
+ let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
+ attn_weights.matmul(&value_states)?
+ };
+ attn_output
+ .transpose(1, 2)?
+ .reshape((b_sz, q_len, self.hidden_size))?
+ .apply(&self.o_proj)
+ }
+}
+
+#[derive(Debug, Clone)]
+struct BlockSparseTop2MLP {
+ w1: Linear,
+ w2: Linear,
+ w3: Linear,
+ act_fn: Activation,
+}
+
+impl BlockSparseTop2MLP {
+ fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let hidden_sz = cfg.hidden_size;
+ let intermediate_sz = cfg.intermediate_size;
+ let w1 = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("w1"))?;
+ let w2 = linear_no_bias(intermediate_sz, hidden_sz, vb.pp("w2"))?;
+ let w3 = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("w3"))?;
+ Ok(Self {
+ w1,
+ w2,
+ w3,
+ act_fn: cfg.hidden_act,
+ })
+ }
+}
+
+impl Module for BlockSparseTop2MLP {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let lhs = xs.apply(&self.w1)?.apply(&self.act_fn)?;
+ let rhs = xs.apply(&self.w3)?;
+ (lhs * rhs)?.apply(&self.w2)
+ }
+}
+
+#[derive(Debug, Clone)]
+struct SparseMoeBlock {
+ gate: Linear,
+ experts: Vec<BlockSparseTop2MLP>,
+ num_experts_per_tok: usize,
+}
+
+impl SparseMoeBlock {
+ fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let gate = linear_no_bias(cfg.hidden_size, cfg.num_local_experts, vb.pp("gate"))?;
+ let mut experts = Vec::with_capacity(cfg.num_local_experts);
+ let vb = vb.pp("experts");
+ for idx in 0..cfg.num_local_experts {
+ let expert = BlockSparseTop2MLP::new(cfg, vb.pp(idx))?;
+ experts.push(expert)
+ }
+ Ok(SparseMoeBlock {
+ gate,
+ experts,
+ num_experts_per_tok: cfg.num_experts_per_tok,
+ })
+ }
+}
+
+impl Module for SparseMoeBlock {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let (b_size, seq_len, hidden_dim) = xs.dims3()?;
+ let xs = xs.reshape(((), hidden_dim))?;
+ let router_logits = xs.apply(&self.gate)?;
+ let routing_weights = candle_nn::ops::softmax_last_dim(&router_logits)?;
+
+ // In order to extract topk, we extract the data from the tensor and manipulate it
+ // directly. Maybe we will want to use some custom ops instead at some point.
+ let routing_weights = routing_weights.to_dtype(DType::F32)?.to_vec2::<f32>()?;
+
+ // routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
+ // top_x contains the row indexes to evaluate for each expert.
+ let mut top_x = vec![vec![]; self.experts.len()];
+ let mut selected_rws = vec![vec![]; self.experts.len()];
+ for (row_idx, rw) in routing_weights.iter().enumerate() {
+ let mut dst = (0..rw.len() as u32).collect::<Vec<u32>>();
+ dst.sort_by(|&i, &j| rw[j as usize].total_cmp(&rw[i as usize]));
+ let mut sum_routing_weights = 0f32;
+ for &expert_idx in dst.iter().take(self.num_experts_per_tok) {
+ let expert_idx = expert_idx as usize;
+ let routing_weight = rw[expert_idx];
+ sum_routing_weights += routing_weight;
+ top_x[expert_idx].push(row_idx as u32);
+ }
+ for &expert_idx in dst.iter().take(self.num_experts_per_tok) {
+ let expert_idx = expert_idx as usize;
+ let routing_weight = rw[expert_idx];
+ selected_rws[expert_idx].push(routing_weight / sum_routing_weights)
+ }
+ }
+
+ // routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+ // expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
+
+ let mut ys = xs.zeros_like()?;
+ for (expert_idx, expert_layer) in self.experts.iter().enumerate() {
+ let top_x = &top_x[expert_idx];
+ if top_x.is_empty() {
+ continue;
+ }
+ let top_x = Tensor::new(top_x.as_slice(), xs.device())?;
+ let selected_rws =
+ Tensor::new(selected_rws[expert_idx].as_slice(), xs.device())?.reshape(((), 1))?;
+ // Index the correct hidden states and compute the expert hidden state for
+ // the current expert. We need to make sure to multiply the output hidden
+ // states by `routing_weights` on the corresponding tokens (top-1 and top-2)
+ let current_state = xs.index_select(&top_x, 0)?.reshape(((), hidden_dim))?;
+ // current_hidden_states = expert_layer(current_state, routing_weights[top_x_list, idx_list, None])
+ let current_hidden_states = expert_layer.forward(&current_state)?;
+ let current_hidden_states = current_hidden_states.broadcast_mul(&selected_rws)?;
+ ys = ys.index_add(&top_x, &current_hidden_states, 0)?;
+ }
+
+ let ys = ys.reshape((b_size, seq_len, hidden_dim))?;
+ Ok(ys)
+ }
+}
+
+#[derive(Debug, Clone)]
+struct DecoderLayer {
+ self_attn: Attention,
+ block_sparse_moe: SparseMoeBlock,
+ input_layernorm: RmsNorm,
+ post_attention_layernorm: RmsNorm,
+}
+
+impl DecoderLayer {
+ fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let self_attn = Attention::new(rotary_emb, cfg, vb.pp("self_attn"))?;
+ let block_sparse_moe = SparseMoeBlock::new(cfg, vb.pp("block_sparse_moe"))?;
+ let input_layernorm =
+ RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?;
+ let post_attention_layernorm = RmsNorm::new(
+ cfg.hidden_size,
+ cfg.rms_norm_eps,
+ vb.pp("post_attention_layernorm"),
+ )?;
+ Ok(Self {
+ self_attn,
+ block_sparse_moe,
+ input_layernorm,
+ post_attention_layernorm,
+ })
+ }
+
+ fn forward(
+ &mut self,
+ xs: &Tensor,
+ attention_mask: Option<&Tensor>,
+ seqlen_offset: usize,
+ ) -> Result<Tensor> {
+ let residual = xs;
+ let xs = self.input_layernorm.forward(xs)?;
+ let xs = self.self_attn.forward(&xs, attention_mask, seqlen_offset)?;
+ let xs = (xs + residual)?;
+ let residual = &xs;
+ let xs = xs
+ .apply(&self.post_attention_layernorm)?
+ .apply(&self.block_sparse_moe)?;
+ residual + xs
+ }
+}
+
+#[derive(Debug, Clone)]
+pub struct Model {
+ embed_tokens: candle_nn::Embedding,
+ layers: Vec<DecoderLayer>,
+ norm: RmsNorm,
+ lm_head: Linear,
+ sliding_window: usize,
+ device: Device,
+ dtype: DType,
+}
+
+impl Model {
+ pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let vb_m = vb.pp("model");
+ let embed_tokens =
+ candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?;
+ let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb_m.device())?);
+ let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
+ let vb_l = vb_m.pp("layers");
+ for layer_idx in 0..cfg.num_hidden_layers {
+ let layer = DecoderLayer::new(rotary_emb.clone(), cfg, vb_l.pp(layer_idx))?;
+ layers.push(layer)
+ }
+ let norm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb_m.pp("norm"))?;
+ let lm_head = linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
+ Ok(Self {
+ embed_tokens,
+ layers,
+ norm,
+ lm_head,
+ sliding_window: cfg.sliding_window,
+ device: vb.device().clone(),
+ dtype: vb.dtype(),
+ })
+ }
+
+ fn prepare_decoder_attention_mask(
+ &self,
+ b_size: usize,
+ tgt_len: usize,
+ seqlen_offset: usize,
+ ) -> Result<Tensor> {
+ // Sliding window mask?
+ let mask: Vec<_> = (0..tgt_len)
+ .flat_map(|i| {
+ (0..tgt_len).map(move |j| {
+ if i < j || j + self.sliding_window < i {
+ f32::NEG_INFINITY
+ } else {
+ 0.
+ }
+ })
+ })
+ .collect();
+ let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?;
+ let mask = if seqlen_offset > 0 {
+ let mask0 = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, &self.device)?;
+ Tensor::cat(&[&mask0, &mask], D::Minus1)?
+ } else {
+ mask
+ };
+ mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))?
+ .to_dtype(self.dtype)
+ }
+
+ pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
+ let (b_size, seq_len) = input_ids.dims2()?;
+ let attention_mask = if seq_len <= 1 {
+ None
+ } else {
+ let mask = self.prepare_decoder_attention_mask(b_size, seq_len, seqlen_offset)?;
+ Some(mask)
+ };
+ let mut xs = self.embed_tokens.forward(input_ids)?;
+ for layer in self.layers.iter_mut() {
+ xs = layer.forward(&xs, attention_mask.as_ref(), seqlen_offset)?
+ }
+ xs.narrow(1, seq_len - 1, 1)?
+ .apply(&self.norm)?
+ .apply(&self.lm_head)
+ }
+}
diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs
index a9a56673..94a3bd5b 100644
--- a/candle-transformers/src/models/mod.rs
+++ b/candle-transformers/src/models/mod.rs
@@ -14,6 +14,7 @@ pub mod llama2_c_weights;
pub mod marian;
pub mod mistral;
pub mod mixformer;
+pub mod mixtral;
pub mod mpt;
pub mod persimmon;
pub mod quantized_blip;