diff options
Diffstat (limited to 'candle-transformers')
-rw-r--r-- | candle-transformers/src/models/llama.rs | 10 |
1 files changed, 10 insertions, 0 deletions
diff --git a/candle-transformers/src/models/llama.rs b/candle-transformers/src/models/llama.rs index f3d482eb..97a40d37 100644 --- a/candle-transformers/src/models/llama.rs +++ b/candle-transformers/src/models/llama.rs @@ -16,6 +16,8 @@ pub struct LlamaConfig { pub rms_norm_eps: f64, #[serde(default = "default_rope")] pub rope_theta: f32, + pub bos_token_id: Option<u32>, + pub eos_token_id: Option<u32>, } fn default_rope() -> f32 { @@ -34,6 +36,8 @@ impl LlamaConfig { rms_norm_eps: self.rms_norm_eps, rope_theta: self.rope_theta, use_flash_attn, + bos_token_id: self.bos_token_id, + eos_token_id: self.eos_token_id, } } } @@ -49,6 +53,8 @@ pub struct Config { pub use_flash_attn: bool, pub rms_norm_eps: f64, pub rope_theta: f32, + pub bos_token_id: Option<u32>, + pub eos_token_id: Option<u32>, } impl Config { @@ -63,6 +69,8 @@ impl Config { use_flash_attn, rms_norm_eps: 1e-6, rope_theta: 10_000.0, + bos_token_id: None, + eos_token_id: None, } } @@ -77,6 +85,8 @@ impl Config { use_flash_attn, rms_norm_eps: 1e-5, rope_theta: 10_000.0, + bos_token_id: None, + eos_token_id: None, } } } |