summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-examples/examples/llama/main.rs6
-rw-r--r--candle-examples/examples/llama/model.rs16
2 files changed, 20 insertions, 2 deletions
diff --git a/candle-examples/examples/llama/main.rs b/candle-examples/examples/llama/main.rs
index d9d1e21a..f3cf17bc 100644
--- a/candle-examples/examples/llama/main.rs
+++ b/candle-examples/examples/llama/main.rs
@@ -127,7 +127,11 @@ fn main() -> Result<()> {
let args = Args::parse();
let device = candle_examples::device(args.cpu)?;
- let config = Config::config_7b(args.use_flash_attn);
+ let config = if args.v1 {
+ Config::config_7b_v1(args.use_flash_attn)
+ } else {
+ Config::config_7b_v2(args.use_flash_attn)
+ };
let dtype = if args.use_f32 { DType::F32 } else { DType::F16 };
let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;
let (llama, tokenizer_filename) = match args.npy {
diff --git a/candle-examples/examples/llama/model.rs b/candle-examples/examples/llama/model.rs
index efb9aeef..dba1d535 100644
--- a/candle-examples/examples/llama/model.rs
+++ b/candle-examples/examples/llama/model.rs
@@ -18,7 +18,21 @@ pub struct Config {
}
impl Config {
- pub fn config_7b(use_flash_attn: bool) -> Self {
+ pub fn config_7b_v1(use_flash_attn: bool) -> Self {
+ Self {
+ hidden_size: 4096,
+ intermediate_size: 11008,
+ vocab_size: 32000,
+ n_layer: 32,
+ n_head: 32,
+ n_embd: 4096,
+ n_key_value_head: 32,
+ use_flash_attn,
+ rms_norm_eps: 1e-6,
+ }
+ }
+
+ pub fn config_7b_v2(use_flash_attn: bool) -> Self {
Self {
hidden_size: 4096,
intermediate_size: 11008,