diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-04-18 22:19:54 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-04-18 22:19:54 +0200 |
commit | e6ee7ba4d46de6e5e1e003319da4a49a3a7a0813 (patch) | |
tree | 85942a2f7d5710e5e78860566b07c06d16d0be0d /candle-transformers | |
parent | 1690ab45d2f636bac256bf101c65eb6fa0a1165a (diff) | |
download | candle-e6ee7ba4d46de6e5e1e003319da4a49a3a7a0813.tar.gz candle-e6ee7ba4d46de6e5e1e003319da4a49a3a7a0813.tar.bz2 candle-e6ee7ba4d46de6e5e1e003319da4a49a3a7a0813.zip |
Llama v3. (#2085)
* Llama v3.
* Tweak the default params + handle special tokens.
* Small tweak.
Diffstat (limited to 'candle-transformers')
-rw-r--r-- | candle-transformers/src/models/llama.rs | 10 |
1 file changed, 10 insertions, 0 deletions
```diff
diff --git a/candle-transformers/src/models/llama.rs b/candle-transformers/src/models/llama.rs
index f3d482eb..97a40d37 100644
--- a/candle-transformers/src/models/llama.rs
+++ b/candle-transformers/src/models/llama.rs
@@ -16,6 +16,8 @@ pub struct LlamaConfig {
     pub rms_norm_eps: f64,
     #[serde(default = "default_rope")]
     pub rope_theta: f32,
+    pub bos_token_id: Option<u32>,
+    pub eos_token_id: Option<u32>,
 }
 
 fn default_rope() -> f32 {
@@ -34,6 +36,8 @@ impl LlamaConfig {
             rms_norm_eps: self.rms_norm_eps,
             rope_theta: self.rope_theta,
             use_flash_attn,
+            bos_token_id: self.bos_token_id,
+            eos_token_id: self.eos_token_id,
         }
     }
 }
@@ -49,6 +53,8 @@ pub struct Config {
     pub use_flash_attn: bool,
     pub rms_norm_eps: f64,
     pub rope_theta: f32,
+    pub bos_token_id: Option<u32>,
+    pub eos_token_id: Option<u32>,
 }
 
 impl Config {
@@ -63,6 +69,8 @@ impl Config {
             use_flash_attn,
             rms_norm_eps: 1e-6,
             rope_theta: 10_000.0,
+            bos_token_id: None,
+            eos_token_id: None,
         }
     }
 
@@ -77,6 +85,8 @@ impl Config {
             use_flash_attn,
             rms_norm_eps: 1e-5,
             rope_theta: 10_000.0,
+            bos_token_id: None,
+            eos_token_id: None,
         }
     }
 }
```