author    Laurent Mazare <laurent.mazare@gmail.com>  2024-04-20 22:19:46 +0200
committer GitHub <noreply@github.com>                2024-04-20 22:19:46 +0200
commit    587ee3bb6fd2b4c2b7bbe7e97751cac96249dd6d (patch)
tree      122db3c84eccdef1dd0451b0c939104ab03a4113 /candle-transformers
parent    dd78422701e9c6f3ca74218e8aedcf032c6c7215 (diff)
Small cleanups to the llama multi-process example. (#2098)
Diffstat (limited to 'candle-transformers')
-rw-r--r--  candle-transformers/src/models/llama.rs | 8
1 file changed, 7 insertions(+), 1 deletion(-)
diff --git a/candle-transformers/src/models/llama.rs b/candle-transformers/src/models/llama.rs
index 945c0e17..57d2f593 100644
--- a/candle-transformers/src/models/llama.rs
+++ b/candle-transformers/src/models/llama.rs
@@ -20,6 +20,12 @@ pub struct LlamaConfig {
     pub eos_token_id: Option<u32>,
 }
+impl LlamaConfig {
+    pub fn num_key_value_heads(&self) -> usize {
+        self.num_key_value_heads.unwrap_or(self.num_attention_heads)
+    }
+}
+
 fn default_rope() -> f32 {
     10_000.0
 }
@@ -32,7 +38,7 @@ impl LlamaConfig {
             vocab_size: self.vocab_size,
             num_hidden_layers: self.num_hidden_layers,
             num_attention_heads: self.num_attention_heads,
-            num_key_value_heads: self.num_key_value_heads.unwrap_or(self.num_attention_heads),
+            num_key_value_heads: self.num_key_value_heads(),
             rms_norm_eps: self.rms_norm_eps,
             rope_theta: self.rope_theta,
             use_flash_attn,
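
For context, the accessor added in this commit centralizes the `unwrap_or` fallback so callers (such as the llama multi-process example this cleanup targets) no longer repeat it when deriving the number of key/value heads. Below is a minimal, self-contained Rust sketch of the same pattern; the `LlamaConfig` here is a simplified stand-in with only the two relevant fields, not the full candle-transformers definition.

// Simplified stand-in for the config type, showing only the fields relevant
// to the key/value head fallback introduced in this commit.
struct LlamaConfig {
    num_attention_heads: usize,
    // `None` for checkpoints that do not use grouped-query attention.
    num_key_value_heads: Option<usize>,
}

impl LlamaConfig {
    // Fall back to `num_attention_heads` when the field is absent,
    // mirroring the accessor added in the diff above.
    fn num_key_value_heads(&self) -> usize {
        self.num_key_value_heads.unwrap_or(self.num_attention_heads)
    }
}

fn main() {
    // No explicit key/value head count: fall back to the attention head count.
    let cfg = LlamaConfig { num_attention_heads: 32, num_key_value_heads: None };
    assert_eq!(cfg.num_key_value_heads(), 32);

    // Grouped-query attention: the explicit value wins.
    let gqa = LlamaConfig { num_attention_heads: 32, num_key_value_heads: Some(8) };
    assert_eq!(gqa.num_key_value_heads(), 8);
}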