summaryrefslogtreecommitdiff
path: root/candle-examples/examples/llama
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-08-22 19:41:10 +0100
committerGitHub <noreply@github.com>2023-08-22 19:41:10 +0100
commitf9ecc8447753d759e776e762ba9309bb90b76bb3 (patch)
tree311d0e2f4dad33ea8174225cc1bfa5bf429ba713 /candle-examples/examples/llama
parent07067b01dce3c63b45fe4bdeb8d972f279e88b45 (diff)
downloadcandle-f9ecc8447753d759e776e762ba9309bb90b76bb3.tar.gz
candle-f9ecc8447753d759e776e762ba9309bb90b76bb3.tar.bz2
candle-f9ecc8447753d759e776e762ba9309bb90b76bb3.zip
GQA support in the quantized model. (#555)
* GQA support in the quantized model. * Fix the reshaping. * Fix the main llama model. * Infer the proper gqa from the model kind.
Diffstat (limited to 'candle-examples/examples/llama')
-rw-r--r--candle-examples/examples/llama/model.rs2
1 file changed, 1 insertion, 1 deletion
diff --git a/candle-examples/examples/llama/model.rs b/candle-examples/examples/llama/model.rs
index 86d13bdb..561c2939 100644
--- a/candle-examples/examples/llama/model.rs
+++ b/candle-examples/examples/llama/model.rs
@@ -291,7 +291,7 @@ impl CausalSelfAttention {
let x = x
.unsqueeze(2)?
.expand((b_sz, n_kv_head, n_rep, seq_len, head_dim))?
- .reshape((b_sz, n_kv_head, n_rep, seq_len, head_dim))?;
+ .reshape((b_sz, n_kv_head * n_rep, seq_len, head_dim))?;
Ok(x)
}
}