summaryrefslogtreecommitdiff
path: root/candle-transformers
diff options
context:
space:
mode:
Diffstat (limited to 'candle-transformers')
-rw-r--r--candle-transformers/src/models/llama.rs10
1 files changed, 10 insertions, 0 deletions
diff --git a/candle-transformers/src/models/llama.rs b/candle-transformers/src/models/llama.rs
index f3d482eb..97a40d37 100644
--- a/candle-transformers/src/models/llama.rs
+++ b/candle-transformers/src/models/llama.rs
@@ -16,6 +16,8 @@ pub struct LlamaConfig {
pub rms_norm_eps: f64,
#[serde(default = "default_rope")]
pub rope_theta: f32,
+ pub bos_token_id: Option<u32>,
+ pub eos_token_id: Option<u32>,
}
fn default_rope() -> f32 {
@@ -34,6 +36,8 @@ impl LlamaConfig {
rms_norm_eps: self.rms_norm_eps,
rope_theta: self.rope_theta,
use_flash_attn,
+ bos_token_id: self.bos_token_id,
+ eos_token_id: self.eos_token_id,
}
}
}
@@ -49,6 +53,8 @@ pub struct Config {
pub use_flash_attn: bool,
pub rms_norm_eps: f64,
pub rope_theta: f32,
+ pub bos_token_id: Option<u32>,
+ pub eos_token_id: Option<u32>,
}
impl Config {
@@ -63,6 +69,8 @@ impl Config {
use_flash_attn,
rms_norm_eps: 1e-6,
rope_theta: 10_000.0,
+ bos_token_id: None,
+ eos_token_id: None,
}
}
@@ -77,6 +85,8 @@ impl Config {
use_flash_attn,
rms_norm_eps: 1e-5,
rope_theta: 10_000.0,
+ bos_token_id: None,
+ eos_token_id: None,
}
}
}