author Nicolas Patry <patry.nicolas@protonmail.com> 2023-08-16 17:47:40 +0200
committer GitHub <noreply@github.com> 2023-08-16 17:47:40 +0200
commit fa4590d7fd2b21ee811dba735851b6ec487f3cee (patch)
tree 6f26ab7595b1c9d4cff34d117682def994cd5c91 /candle-examples/examples/llama/model.rs
parent 2e206e269da311cb0c3bde164e6c2ecb9286034e (diff)
parent 102fa4c2e3e833199517a9400d0c2310ce18d62e (diff)
Merge pull request #469 from huggingface/fix_llama_v1
Fixing llamav1
Diffstat (limited to 'candle-examples/examples/llama/model.rs')
 candle-examples/examples/llama/model.rs | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/candle-examples/examples/llama/model.rs b/candle-examples/examples/llama/model.rs
index 940c980c..751b5902 100644
--- a/candle-examples/examples/llama/model.rs
+++ b/candle-examples/examples/llama/model.rs
@@ -13,7 +13,7 @@ pub struct LlamaConfig {
     pub vocab_size: usize,
     pub num_hidden_layers: usize,
     pub num_attention_heads: usize,
-    pub num_key_value_heads: usize,
+    pub num_key_value_heads: Option<usize>,
     pub rms_norm_eps: f64,
 }
@@ -25,7 +25,7 @@ impl LlamaConfig {
             vocab_size: self.vocab_size,
             num_hidden_layers: self.num_hidden_layers,
             num_attention_heads: self.num_attention_heads,
-            num_key_value_heads: self.num_key_value_heads,
+            num_key_value_heads: self.num_key_value_heads.unwrap_or(self.num_attention_heads),
             rms_norm_eps: self.rms_norm_eps,
             use_flash_attn,
         }
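
The change makes num_key_value_heads optional: LLaMA v1 config.json files predate grouped-query attention and omit the field, so a plain usize failed to deserialize for v1 checkpoints. When the field is absent, into_config falls back to num_attention_heads, which is exactly plain multi-head attention. A minimal sketch of the behavior, assuming serde-based deserialization as used by the example; the JSON snippets and field values below are illustrative, not taken from real checkpoint configs:

// Requires serde (with the "derive" feature) and serde_json.
use serde::Deserialize;

#[derive(Deserialize)]
struct LlamaConfig {
    num_attention_heads: usize,
    // Missing from LLaMA v1 config.json files; Option lets serde
    // accept both v1 (field absent -> None) and v2 (field present).
    num_key_value_heads: Option<usize>,
}

fn main() -> serde_json::Result<()> {
    // LLaMA v1 style config: no num_key_value_heads field at all.
    let v1: LlamaConfig = serde_json::from_str(r#"{ "num_attention_heads": 32 }"#)?;
    // Fallback: one key/value head per query head (multi-head attention).
    assert_eq!(v1.num_key_value_heads.unwrap_or(v1.num_attention_heads), 32);

    // LLaMA v2 style config: grouped-query attention with fewer KV heads.
    let v2: LlamaConfig =
        serde_json::from_str(r#"{ "num_attention_heads": 64, "num_key_value_heads": 8 }"#)?;
    assert_eq!(v2.num_key_value_heads.unwrap_or(v2.num_attention_heads), 8);
    Ok(())
}

Defaulting to num_attention_heads reproduces the pre-GQA behavior, so v1 configs load unchanged while v2-style configs can still specify a smaller number of key/value heads.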