Diffstat (limited to 'candle-transformers/src/models/quantized_llama.rs')
-rw-r--r--  candle-transformers/src/models/quantized_llama.rs  171
1 file changed, 150 insertions(+), 21 deletions(-)
diff --git a/candle-transformers/src/models/quantized_llama.rs b/candle-transformers/src/models/quantized_llama.rs
index 44d89f40..1fb2d9e2 100644
--- a/candle-transformers/src/models/quantized_llama.rs
+++ b/candle-transformers/src/models/quantized_llama.rs
@@ -48,15 +48,109 @@ impl QMatMul {
}
#[derive(Debug, Clone)]
+struct Mlp {
+ feed_forward_w1: QMatMul,
+ feed_forward_w2: QMatMul,
+ feed_forward_w3: QMatMul,
+}
+
+impl Module for Mlp {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let w1 = self.feed_forward_w1.forward(xs)?;
+ let w3 = self.feed_forward_w3.forward(xs)?;
+ self.feed_forward_w2
+ .forward(&(candle_nn::ops::silu(&w1)? * w3)?)
+ }
+}
+
+#[derive(Debug, Clone)]
+enum MlpOrMoe {
+ Mlp(Mlp),
+ MoE {
+ n_expert_used: usize,
+ feed_forward_gate_inp: QMatMul,
+ experts: Vec<Mlp>,
+ },
+}
+
+impl Module for MlpOrMoe {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ match self {
+ Self::MoE {
+ feed_forward_gate_inp,
+ experts,
+ n_expert_used,
+ } => {
+ let (b_size, seq_len, hidden_dim) = xs.dims3()?;
+ let xs = xs.reshape(((), hidden_dim))?;
+ let router_logits = feed_forward_gate_inp.forward(&xs)?;
+ let routing_weights = candle_nn::ops::softmax_last_dim(&router_logits)?;
+
+ // In order to extract topk, we extract the data from the tensor and manipulate it
+ // directly. Maybe we will want to use some custom ops instead at some point.
+ let routing_weights = routing_weights.to_dtype(DType::F32)?.to_vec2::<f32>()?;
+
+ // routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
+ // top_x contains the row indexes to evaluate for each expert.
+ let mut top_x = vec![vec![]; experts.len()];
+ let mut selected_rws = vec![vec![]; experts.len()];
+ for (row_idx, rw) in routing_weights.iter().enumerate() {
+ let mut dst = (0..rw.len() as u32).collect::<Vec<u32>>();
+ dst.sort_by(|&i, &j| rw[j as usize].total_cmp(&rw[i as usize]));
+ let mut sum_routing_weights = 0f32;
+ for &expert_idx in dst.iter().take(*n_expert_used) {
+ let expert_idx = expert_idx as usize;
+ let routing_weight = rw[expert_idx];
+ sum_routing_weights += routing_weight;
+ top_x[expert_idx].push(row_idx as u32);
+ }
+ for &expert_idx in dst.iter().take(*n_expert_used) {
+ let expert_idx = expert_idx as usize;
+ let routing_weight = rw[expert_idx];
+ selected_rws[expert_idx].push(routing_weight / sum_routing_weights)
+ }
+ }
+
+ // routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+ // expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
+
+ let mut ys = xs.zeros_like()?;
+ for (expert_idx, expert_layer) in experts.iter().enumerate() {
+ let top_x = &top_x[expert_idx];
+ if top_x.is_empty() {
+ continue;
+ }
+ let top_x = Tensor::new(top_x.as_slice(), xs.device())?;
+ let selected_rws =
+ Tensor::new(selected_rws[expert_idx].as_slice(), xs.device())?
+ .reshape(((), 1))?;
+ // Index the correct hidden states and compute the expert hidden state for
+ // the current expert. We need to make sure to multiply the output hidden
+ // states by `routing_weights` on the corresponding tokens (top-1 and top-2)
+ let current_state = xs.index_select(&top_x, 0)?.reshape(((), hidden_dim))?;
+ // current_hidden_states = expert_layer(current_state, routing_weights[top_x_list, idx_list, None])
+ let current_hidden_states = expert_layer.forward(&current_state)?;
+ let current_hidden_states =
+ current_hidden_states.broadcast_mul(&selected_rws)?;
+ ys = ys.index_add(&top_x, &current_hidden_states, 0)?;
+ }
+
+ let ys = ys.reshape((b_size, seq_len, hidden_dim))?;
+ Ok(ys)
+ }
+ Self::Mlp(mlp) => mlp.forward(xs),
+ }
+ }
+}
+
+#[derive(Debug, Clone)]
struct LayerWeights {
attention_wq: QMatMul,
attention_wk: QMatMul,
attention_wv: QMatMul,
attention_wo: QMatMul,
attention_norm: RmsNorm,
- feed_forward_w1: QMatMul,
- feed_forward_w2: QMatMul,
- feed_forward_w3: QMatMul,
+ mlp_or_moe: MlpOrMoe,
ffn_norm: RmsNorm,
n_head: usize,
n_kv_head: usize,
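
Note: the routing in `MlpOrMoe::forward` above is Mixtral-style: softmax over the gate logits, keep the `n_expert_used` largest weights per token, renormalize them so they sum to one, then run only the selected experts. A minimal standalone sketch of the per-row selection step, on plain slices rather than tensors; `route_row` is an illustrative name and not part of the patch:

fn route_row(rw: &[f32], n_expert_used: usize) -> Vec<(usize, f32)> {
    // Sort expert indices by descending routing weight, as the loop above does.
    let mut idx: Vec<usize> = (0..rw.len()).collect();
    idx.sort_by(|&i, &j| rw[j].total_cmp(&rw[i]));
    let top = &idx[..n_expert_used.min(idx.len())];
    // Renormalize the selected weights so they sum to one.
    let sum: f32 = top.iter().map(|&i| rw[i]).sum();
    top.iter().map(|&i| (i, rw[i] / sum)).collect()
}

fn main() {
    // Softmax output for one token over four experts, top-2 routing:
    // experts 1 and 2 win, with weights 0.625 and 0.375.
    println!("{:?}", route_row(&[0.1, 0.5, 0.3, 0.1], 2));
}
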
@@ -212,9 +306,16 @@ impl ModelWeights {
let attention_wk = ct.remove(&format!("{prefix}.attention.wk.weight"))?;
let attention_wv = ct.remove(&format!("{prefix}.attention.wv.weight"))?;
let attention_wo = ct.remove(&format!("{prefix}.attention.wo.weight"))?;
- let feed_forward_w1 = ct.remove(&format!("{prefix}.feed_forward.w1.weight"))?;
- let feed_forward_w2 = ct.remove(&format!("{prefix}.feed_forward.w2.weight"))?;
- let feed_forward_w3 = ct.remove(&format!("{prefix}.feed_forward.w3.weight"))?;
+ let mlp_or_moe = {
+ let feed_forward_w1 = ct.remove(&format!("{prefix}.feed_forward.w1.weight"))?;
+ let feed_forward_w2 = ct.remove(&format!("{prefix}.feed_forward.w2.weight"))?;
+ let feed_forward_w3 = ct.remove(&format!("{prefix}.feed_forward.w3.weight"))?;
+ MlpOrMoe::Mlp(Mlp {
+ feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?,
+ feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?,
+ feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3)?,
+ })
+ };
let attention_norm = ct.remove(&format!("{prefix}.attention_norm.weight"))?;
let ffn_norm = ct.remove(&format!("{prefix}.ffn_norm.weight"))?;
let span_attn = tracing::span!(tracing::Level::TRACE, "attn");
@@ -226,9 +327,7 @@ impl ModelWeights {
attention_wv: QMatMul::from_qtensor(attention_wv)?,
attention_wo: QMatMul::from_qtensor(attention_wo)?,
attention_norm: RmsNorm::new(attention_norm, 1e-5)?,
- feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?,
- feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?,
- feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3)?,
+ mlp_or_moe,
ffn_norm: RmsNorm::new(ffn_norm, 1e-5)?,
n_head: ct.hparams.n_head as usize,
n_kv_head: ct.hparams.n_head as usize / gqa,
@@ -265,6 +364,12 @@ impl ModelWeights {
};
// Parameter extraction from metadata.
+ let n_expert = md_get("llama.expert_count")
+ .and_then(|v| v.to_u32())
+ .unwrap_or(0) as usize;
+ let n_expert_used = md_get("llama.expert_used_count")
+ .and_then(|v| v.to_u32())
+ .unwrap_or(0) as usize;
let head_count = md_get("llama.attention.head_count")?.to_u32()? as usize;
let head_count_kv = md_get("llama.attention.head_count_kv")?.to_u32()? as usize;
let block_count = md_get("llama.block_count")?.to_u32()? as usize;
@@ -289,9 +394,38 @@ impl ModelWeights {
let attention_wk = ct.tensor(reader, &format!("{prefix}.attn_k.weight"))?;
let attention_wv = ct.tensor(reader, &format!("{prefix}.attn_v.weight"))?;
let attention_wo = ct.tensor(reader, &format!("{prefix}.attn_output.weight"))?;
- let feed_forward_w1 = ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"))?;
- let feed_forward_w2 = ct.tensor(reader, &format!("{prefix}.ffn_down.weight"))?;
- let feed_forward_w3 = ct.tensor(reader, &format!("{prefix}.ffn_up.weight"))?;
+ let mlp_or_moe = if n_expert <= 1 {
+ let feed_forward_w1 = ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"))?;
+ let feed_forward_w2 = ct.tensor(reader, &format!("{prefix}.ffn_down.weight"))?;
+ let feed_forward_w3 = ct.tensor(reader, &format!("{prefix}.ffn_up.weight"))?;
+ MlpOrMoe::Mlp(Mlp {
+ feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?,
+ feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?,
+ feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3)?,
+ })
+ } else {
+ let feed_forward_gate_inp =
+ ct.tensor(reader, &format!("{prefix}.ffn_gate_inp.weight"))?;
+ let mut experts = Vec::with_capacity(n_expert);
+ for i in 0..n_expert {
+ let feed_forward_w1 =
+ ct.tensor(reader, &format!("{prefix}.ffn_gate.{i}.weight"))?;
+ let feed_forward_w2 =
+ ct.tensor(reader, &format!("{prefix}.ffn_down.{i}.weight"))?;
+ let feed_forward_w3 =
+ ct.tensor(reader, &format!("{prefix}.ffn_up.{i}.weight"))?;
+ experts.push(Mlp {
+ feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?,
+ feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?,
+ feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3)?,
+ })
+ }
+ MlpOrMoe::MoE {
+ n_expert_used,
+ feed_forward_gate_inp: QMatMul::from_qtensor(feed_forward_gate_inp)?,
+ experts,
+ }
+ };
let attention_norm = ct.tensor(reader, &format!("{prefix}.attn_norm.weight"))?;
let ffn_norm = ct.tensor(reader, &format!("{prefix}.ffn_norm.weight"))?;
let span_attn = tracing::span!(tracing::Level::TRACE, "attn");
@@ -303,9 +437,7 @@ impl ModelWeights {
attention_wv: QMatMul::from_qtensor(attention_wv)?,
attention_wo: QMatMul::from_qtensor(attention_wo)?,
attention_norm: RmsNorm::new(attention_norm, rms_norm_eps)?,
- feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?,
- feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?,
- feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3)?,
+ mlp_or_moe,
ffn_norm: RmsNorm::new(ffn_norm, rms_norm_eps)?,
n_head: head_count,
n_kv_head: head_count_kv,
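
Note: the GGUF MoE layout requested by the loop above pairs one `ffn_gate_inp.weight` router per block with a gate/down/up triple per expert. A small sketch of the tensor names one block ends up asking for, assuming the `blk.{layer_idx}` prefix this loader builds for `prefix` elsewhere in the file; the helper is illustrative only:

fn expert_tensor_names(layer_idx: usize, n_expert: usize) -> Vec<String> {
    let prefix = format!("blk.{layer_idx}");
    // One router matrix per block...
    let mut names = vec![format!("{prefix}.ffn_gate_inp.weight")];
    // ...plus a gate/down/up triple per expert.
    for i in 0..n_expert {
        for part in ["ffn_gate", "ffn_down", "ffn_up"] {
            names.push(format!("{prefix}.{part}.{i}.weight"));
        }
    }
    names
}

fn main() {
    // e.g. an 8-expert checkpoint: 1 router + 8 * 3 = 25 FFN tensors for block 0
    // instead of the 3 used by the dense path.
    for name in expert_tensor_names(0, 8) {
        println!("{name}");
    }
}
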
@@ -360,12 +492,9 @@ impl ModelWeights {
let _enter = layer.span_mlp.enter();
let residual = &x;
let x = layer.ffn_norm.forward(&x)?;
- let w1 = layer.feed_forward_w1.forward(&x)?;
- let w3 = layer.feed_forward_w3.forward(&x)?;
- let mlp = layer
- .feed_forward_w2
- .forward(&(candle_nn::ops::silu(&w1)? * w3)?)?;
- layer_in = (mlp + residual)?;
+ let x = layer.mlp_or_moe.forward(&x)?;
+ let x = (x + residual)?;
+ layer_in = x
}
let x = self.norm.forward(&layer_in)?;
let x = x.i((.., seq_len - 1, ..))?;
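
The gather/scatter pair the MoE branch relies on (`index_select` to pull out the token rows routed to one expert, `index_add` to accumulate the weighted expert outputs back into place) can be exercised on its own. A minimal CPU sketch, using the crates.io names `candle_core::{Device, Tensor}` rather than this repo's internal `candle` alias; the data is made up:

use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Four "token" rows of width two.
    let xs = Tensor::new(&[[1f32, 1.], [2., 2.], [3., 3.], [4., 4.]], &dev)?;
    // Rows 1 and 3 are routed to this expert.
    let top_x = Tensor::new(&[1u32, 3], &dev)?;
    // Gather the routed rows and pretend the expert doubles them...
    let expert_out = (xs.index_select(&top_x, 0)? * 2.)?;
    // ...then scatter-add the results back into a zero buffer, as `ys` is built above.
    let ys = xs.zeros_like()?.index_add(&top_x, &expert_out, 0)?;
    println!("{ys}"); // rows 0 and 2 stay zero, rows 1 and 3 hold the doubled values
    Ok(())
}
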