author    | Laurent Mazare <laurent.mazare@gmail.com> | 2024-04-20 22:19:46 +0200
committer | GitHub <noreply@github.com>               | 2024-04-20 22:19:46 +0200
commit    | 587ee3bb6fd2b4c2b7bbe7e97751cac96249dd6d (patch)
tree      | 122db3c84eccdef1dd0451b0c939104ab03a4113 /candle-transformers
parent    | dd78422701e9c6f3ca74218e8aedcf032c6c7215 (diff)
Small cleanups to the llama multi-process example. (#2098)
Diffstat (limited to 'candle-transformers')
-rw-r--r-- | candle-transformers/src/models/llama.rs | 8
1 file changed, 7 insertions(+), 1 deletion(-)
diff --git a/candle-transformers/src/models/llama.rs b/candle-transformers/src/models/llama.rs
index 945c0e17..57d2f593 100644
--- a/candle-transformers/src/models/llama.rs
+++ b/candle-transformers/src/models/llama.rs
@@ -20,6 +20,12 @@ pub struct LlamaConfig {
     pub eos_token_id: Option<u32>,
 }
 
+impl LlamaConfig {
+    pub fn num_key_value_heads(&self) -> usize {
+        self.num_key_value_heads.unwrap_or(self.num_attention_heads)
+    }
+}
+
 fn default_rope() -> f32 {
     10_000.0
 }
@@ -32,7 +38,7 @@ impl LlamaConfig {
             vocab_size: self.vocab_size,
             num_hidden_layers: self.num_hidden_layers,
             num_attention_heads: self.num_attention_heads,
-            num_key_value_heads: self.num_key_value_heads.unwrap_or(self.num_attention_heads),
+            num_key_value_heads: self.num_key_value_heads(),
             rms_norm_eps: self.rms_norm_eps,
             rope_theta: self.rope_theta,
             use_flash_attn,
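The change lifts the `unwrap_or(num_attention_heads)` fallback into a `num_key_value_heads()` accessor on `LlamaConfig`, so callers such as the multi-process example can ask for the effective number of KV heads without re-deriving the default. The sketch below mirrors that pattern with a simplified stand-in type; `LlamaConfigSketch` and the specific head counts are illustrative assumptions, not the actual candle-transformers API.

```rust
// A minimal sketch (assumed, simplified types) of the accessor pattern added
// in this commit: the serialized config may omit `num_key_value_heads`, and
// the effective value falls back to `num_attention_heads` (plain multi-head
// attention) when it is absent.
struct LlamaConfigSketch {
    num_attention_heads: usize,
    num_key_value_heads: Option<usize>,
}

impl LlamaConfigSketch {
    fn num_key_value_heads(&self) -> usize {
        // Grouped-query models set this explicitly; otherwise use one KV head
        // per attention head.
        self.num_key_value_heads.unwrap_or(self.num_attention_heads)
    }
}

fn main() {
    // Config without an explicit KV head count: defaults to 32 (MHA).
    let mha = LlamaConfigSketch {
        num_attention_heads: 32,
        num_key_value_heads: None,
    };
    assert_eq!(mha.num_key_value_heads(), 32);

    // Grouped-query config: 8 KV heads shared by 32 query heads.
    let gqa = LlamaConfigSketch {
        num_attention_heads: 32,
        num_key_value_heads: Some(8),
    };
    assert_eq!(gqa.num_key_value_heads(), 8);

    println!(
        "mha kv heads = {}, gqa kv heads = {}",
        mha.num_key_value_heads(),
        gqa.num_key_value_heads()
    );
}
```

Keeping the fallback in one accessor means grouped-query configs (explicit, smaller KV head count) and plain multi-head configs (field omitted) are handled identically wherever the config is consumed, which is what `into_config` now relies on in the second hunk above.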