summary refs log tree commit diff
path: root/candle-examples/examples/llama/model.rs
diff options
context:
space:
mode:
author Laurent Mazare <laurent.mazare@gmail.com> 2023-07-28 18:40:59 +0100
committer GitHub <noreply@github.com> 2023-07-28 18:40:59 +0100
commit 50d8273ae4692e040045dbc8fca09f261fa8c237 (patch)
tree 4890f057cb80bcf3934e9d15cb65d3405661cf49 /candle-examples/examples/llama/model.rs
parent7513a5e005bfa7e205345aaeeb6f660cf178a598 (diff)
download candle-50d8273ae4692e040045dbc8fca09f261fa8c237.tar.gz
candle-50d8273ae4692e040045dbc8fca09f261fa8c237.tar.bz2
candle-50d8273ae4692e040045dbc8fca09f261fa8c237.zip
Support both llama v1 and llama v2. (#272)
Diffstat (limited to 'candle-examples/examples/llama/model.rs')
-rw-r--r-- candle-examples/examples/llama/model.rs 16
1 files changed, 15 insertions, 1 deletions
diff --git a/candle-examples/examples/llama/model.rs b/candle-examples/examples/llama/model.rs
index efb9aeef..dba1d535 100644
--- a/candle-examples/examples/llama/model.rs
+++ b/candle-examples/examples/llama/model.rs
@@ -18,7 +18,21 @@ pub struct Config {
}
impl Config {
- pub fn config_7b(use_flash_attn: bool) -> Self {
+ pub fn config_7b_v1(use_flash_attn: bool) -> Self {
+ Self {
+ hidden_size: 4096,
+ intermediate_size: 11008,
+ vocab_size: 32000,
+ n_layer: 32,
+ n_head: 32,
+ n_embd: 4096,
+ n_key_value_head: 32,
+ use_flash_attn,
+ rms_norm_eps: 1e-6,
+ }
+ }
+
+ pub fn config_7b_v2(use_flash_attn: bool) -> Self {
Self {
hidden_size: 4096,
intermediate_size: 11008,