Diffstat (limited to 'candle-examples/examples/quantized/main.rs')
 candle-examples/examples/quantized/main.rs | 36 +++++++++++++++++++++++++++++++-----
 1 file changed, 31 insertions(+), 5 deletions(-)
diff --git a/candle-examples/examples/quantized/main.rs b/candle-examples/examples/quantized/main.rs
index 8411142e..477c695f 100644
--- a/candle-examples/examples/quantized/main.rs
+++ b/candle-examples/examples/quantized/main.rs
@@ -68,6 +68,7 @@ struct LayerWeights {
     feed_forward_w3: QMatMul,
     ffn_norm: RmsNorm,
     n_head: usize,
+    n_kv_head: usize,
     head_dim: usize,
     cos: Tensor,
     sin: Tensor,
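With grouped-query attention the query projection keeps n_head heads while the key/value projections only produce n_kv_head heads (for the LLaMA-2 70B checkpoint, 64 query heads share 8 KV heads), so the layer has to track both counts separately.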
@@ -125,10 +126,10 @@ impl LayerWeights {
             .reshape((b_sz, seq_len, self.n_head, self.head_dim))?
             .transpose(1, 2)?;
         let k = k
-            .reshape((b_sz, seq_len, self.n_head, self.head_dim))?
+            .reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
             .transpose(1, 2)?;
         let v = v
-            .reshape((b_sz, seq_len, self.n_head, self.head_dim))?
+            .reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
             .transpose(1, 2)?;

         let q = self.apply_rotary_emb(&q, index_pos)?;
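These reshapes are where the new field matters: q is split into n_head heads while k and v only carry n_kv_head. A shape walkthrough with hypothetical sizes (b_sz = 1, seq_len = 16, and 70B-like n_head = 64, n_kv_head = 8, head_dim = 128):

    q: (1, 16, 64, 128) -> transpose(1, 2) -> (1, 64, 16, 128)
    k: (1, 16,  8, 128) -> transpose(1, 2) -> (1,  8, 16, 128)
    v: (1, 16,  8, 128) -> transpose(1, 2) -> (1,  8, 16, 128)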
@@ -144,7 +145,9 @@ impl LayerWeights {
         };
         self.kv_cache = Some((k.clone(), v.clone()));

-        // If we start supporting MQA, we need to repeat the k and v tensors here.
+        // Support for MQA/GQA, useful for the 70B models.
+        let k = self.repeat_kv(k)?;
+        let v = self.repeat_kv(v)?;
         let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
         let mask = mask.broadcast_as(att.shape())?;
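The q.matmul(&k.t()?)? computes, per head, att = q · kᵀ / sqrt(head_dim), which requires q and k to have the same number of heads. That is why the cached k and v are passed through repeat_kv first: each KV head is duplicated n_head / n_kv_head times before the attention scores are formed.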
@@ -156,6 +159,20 @@ impl LayerWeights {
         let y = self.attention_wo.forward(&y)?;
         Ok(y)
     }
+
+    fn repeat_kv(&self, x: Tensor) -> Result<Tensor> {
+        let n_rep = self.n_head / self.n_kv_head;
+        if n_rep == 1 {
+            Ok(x)
+        } else {
+            let (b_sz, n_kv_head, seq_len, head_dim) = x.dims4()?;
+            let x = x
+                .unsqueeze(2)?
+                .expand((b_sz, n_kv_head, n_rep, seq_len, head_dim))?
+                .reshape((b_sz, n_kv_head * n_rep, seq_len, head_dim))?;
+            Ok(x)
+        }
+    }
 }

 struct ModelWeights {
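For reference, the expand-and-reshape trick in repeat_kv can be exercised on its own. A minimal standalone sketch, assuming the candle_core crate from crates.io (inside the candle workspace the same crate is imported as candle) and made-up toy shapes:

use candle_core::{DType, Device, Result, Tensor};

// Repeat each KV head n_rep times: (b, n_kv_head, t, d) -> (b, n_kv_head * n_rep, t, d).
fn repeat_kv(x: Tensor, n_rep: usize) -> Result<Tensor> {
    if n_rep == 1 {
        return Ok(x);
    }
    let (b_sz, n_kv_head, seq_len, head_dim) = x.dims4()?;
    x.unsqueeze(2)?                                            // (b, kv, 1, t, d)
        .expand((b_sz, n_kv_head, n_rep, seq_len, head_dim))?  // (b, kv, rep, t, d)
        .reshape((b_sz, n_kv_head * n_rep, seq_len, head_dim)) // (b, kv * rep, t, d)
}

fn main() -> Result<()> {
    // Toy tensor: 2 KV heads repeated 4x to match 8 query heads.
    let x = Tensor::zeros((1, 2, 3, 4), DType::F32, &Device::Cpu)?;
    let y = repeat_kv(x, 4)?;
    assert_eq!(y.dims(), &[1, 8, 3, 4]);
    Ok(())
}

This mirrors the member function above, except that n_rep is passed explicitly instead of being derived from self.n_head / self.n_kv_head.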
@@ -179,7 +196,7 @@ impl WeightMap {
 }

 impl ModelWeights {
-    fn new(mut ct: Content) -> Result<Self> {
+    fn new(mut ct: Content, gqa: usize) -> Result<Self> {
         let cpu = &Device::Cpu;
         let head_dim = (ct.hparams.n_embd / ct.hparams.n_head) as usize;
@@ -226,6 +243,7 @@ impl ModelWeights {
                 feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3),
                 ffn_norm: RmsNorm::new(ffn_norm)?,
                 n_head: ct.hparams.n_head as usize,
+                n_kv_head: ct.hparams.n_head as usize / gqa,
                 head_dim: (ct.hparams.n_embd / ct.hparams.n_head) as usize,
                 cos: cos.clone(),
                 sin: sin.clone(),
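Here gqa is interpreted as the number of query heads per KV head, so n_kv_head = n_head / gqa: with the 70B hparams (n_head = 64) and gqa = 8 this gives 8 KV heads, while gqa = 1 keeps n_kv_head == n_head, i.e. ordinary multi-head attention.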
@@ -347,6 +365,10 @@ struct Args {
     /// The model size to use.
     #[arg(long, default_value = "7b")]
     which: Which,
+
+    /// Grouped-query attention, use 8 for the 70B version of LLaMAv2.
+    #[arg(long)]
+    gqa: Option<usize>,
 }

 impl Args {
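The flag is an Option so that leaving it unset falls through to a per-model default chosen in main() (next hunk), rather than hard-coding a value of 1.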
@@ -468,7 +490,11 @@ fn main() -> anyhow::Result<()> {
         start.elapsed().as_secs_f32(),
     );
     println!("params: {:?}", model.hparams);
-    let mut model = ModelWeights::new(model)?;
+    let default_gqa = match args.which {
+        Which::L7b | Which::L13b => 1,
+        Which::L70b => 8,
+    };
+    let mut model = ModelWeights::new(model, args.gqa.unwrap_or(default_gqa))?;
     println!("model built");

     let tokenizer = args.tokenizer()?;
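With these defaults the 7B and 13B models keep gqa = 1 (no KV repetition) and the 70B model gets gqa = 8, so the flag only needs to be passed to override the default. For example, running the 70B model no longer requires --gqa 8 explicitly: cargo run --release --example quantized -- --which 70b (the "70b" value name is assumed here from the "7b" default shown earlier).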