summary | refs | log | tree | commit | diff
path: root/candle-transformers
diff options
context:
space:
mode:
author: Laurent Mazare <laurent.mazare@gmail.com> 2024-04-18 22:19:54 +0200
committer: GitHub <noreply@github.com> 2024-04-18 22:19:54 +0200
commit: e6ee7ba4d46de6e5e1e003319da4a49a3a7a0813 (patch)
tree: 85942a2f7d5710e5e78860566b07c06d16d0be0d /candle-transformers
parent: 1690ab45d2f636bac256bf101c65eb6fa0a1165a (diff)
download: candle-e6ee7ba4d46de6e5e1e003319da4a49a3a7a0813.tar.gz
candle-e6ee7ba4d46de6e5e1e003319da4a49a3a7a0813.tar.bz2
candle-e6ee7ba4d46de6e5e1e003319da4a49a3a7a0813.zip
Llama v3. (#2085)
* Llama v3.
* Tweak the default params + handle special tokens.
* Small tweak.
Diffstat (limited to 'candle-transformers')
-rw-r--r-- candle-transformers/src/models/llama.rs 10
1 file changed, 10 insertions, 0 deletions
diff --git a/candle-transformers/src/models/llama.rs b/candle-transformers/src/models/llama.rs
index f3d482eb..97a40d37 100644
--- a/candle-transformers/src/models/llama.rs
+++ b/candle-transformers/src/models/llama.rs
@@ -16,6 +16,8 @@ pub struct LlamaConfig {
pub rms_norm_eps: f64,
#[serde(default = "default_rope")]
pub rope_theta: f32,
+ pub bos_token_id: Option<u32>,
+ pub eos_token_id: Option<u32>,
}
fn default_rope() -> f32 {
@@ -34,6 +36,8 @@ impl LlamaConfig {
rms_norm_eps: self.rms_norm_eps,
rope_theta: self.rope_theta,
use_flash_attn,
+ bos_token_id: self.bos_token_id,
+ eos_token_id: self.eos_token_id,
}
}
}
@@ -49,6 +53,8 @@ pub struct Config {
pub use_flash_attn: bool,
pub rms_norm_eps: f64,
pub rope_theta: f32,
+ pub bos_token_id: Option<u32>,
+ pub eos_token_id: Option<u32>,
}
impl Config {
@@ -63,6 +69,8 @@ impl Config {
use_flash_attn,
rms_norm_eps: 1e-6,
rope_theta: 10_000.0,
+ bos_token_id: None,
+ eos_token_id: None,
}
}
@@ -77,6 +85,8 @@ impl Config {
use_flash_attn,
rms_norm_eps: 1e-5,
rope_theta: 10_000.0,
+ bos_token_id: None,
+ eos_token_id: None,
}
}
}