summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-examples/examples/llama/model.rs4
1 files changed, 2 insertions, 2 deletions
diff --git a/candle-examples/examples/llama/model.rs b/candle-examples/examples/llama/model.rs
index 940c980c..751b5902 100644
--- a/candle-examples/examples/llama/model.rs
+++ b/candle-examples/examples/llama/model.rs
@@ -13,7 +13,7 @@ pub struct LlamaConfig {
pub vocab_size: usize,
pub num_hidden_layers: usize,
pub num_attention_heads: usize,
- pub num_key_value_heads: usize,
+ pub num_key_value_heads: Option<usize>,
pub rms_norm_eps: f64,
}
@@ -25,7 +25,7 @@ impl LlamaConfig {
vocab_size: self.vocab_size,
num_hidden_layers: self.num_hidden_layers,
num_attention_heads: self.num_attention_heads,
- num_key_value_heads: self.num_key_value_heads,
+ num_key_value_heads: self.num_key_value_heads.unwrap_or(self.num_attention_heads),
rms_norm_eps: self.rms_norm_eps,
use_flash_attn,
}