summaryrefslogtreecommitdiff
path: root/candle-examples/examples/llama
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2024-04-19 08:50:06 +0200
committerGitHub <noreply@github.com>2024-04-19 08:50:06 +0200
commit9c532aef4751ad33cb74bb81b506cdb3011b5bef (patch)
tree1773f3dcdec38376ea9faf965c2e7228a47595ed /candle-examples/examples/llama
parentf7a646823896fefa89dfdd7138c34eeb8e0a5336 (diff)
downloadcandle-9c532aef4751ad33cb74bb81b506cdb3011b5bef.tar.gz
candle-9c532aef4751ad33cb74bb81b506cdb3011b5bef.tar.bz2
candle-9c532aef4751ad33cb74bb81b506cdb3011b5bef.zip
Also enable llama-v3 8b instruct. (#2088)
Diffstat (limited to 'candle-examples/examples/llama')
-rw-r--r--candle-examples/examples/llama/main.rs4
1 files changed, 3 insertions, 1 deletions
diff --git a/candle-examples/examples/llama/main.rs b/candle-examples/examples/llama/main.rs
index dbff1b7d..32763153 100644
--- a/candle-examples/examples/llama/main.rs
+++ b/candle-examples/examples/llama/main.rs
@@ -32,6 +32,7 @@ enum Which {
V1,
V2,
V3,
+ V3Instruct,
#[value(name = "solar-10.7b")]
Solar10_7B,
#[value(name = "tiny-llama-1.1b-chat")]
@@ -127,6 +128,7 @@ fn main() -> Result<()> {
Which::V1 => "Narsil/amall-7b".to_string(),
Which::V2 => "meta-llama/Llama-2-7b-hf".to_string(),
Which::V3 => "meta-llama/Meta-Llama-3-8B".to_string(),
+ Which::V3Instruct => "meta-llama/Meta-Llama-3-8B-Instruct".to_string(),
Which::Solar10_7B => "upstage/SOLAR-10.7B-v1.0".to_string(),
Which::TinyLlama1_1BChat => "TinyLlama/TinyLlama-1.1B-Chat-v1.0".to_string(),
});
@@ -140,7 +142,7 @@ fn main() -> Result<()> {
let config = config.into_config(args.use_flash_attn);
let filenames = match args.which {
- Which::V1 | Which::V2 | Which::V3 | Which::Solar10_7B => {
+ Which::V1 | Which::V2 | Which::V3 | Which::V3Instruct | Which::Solar10_7B => {
candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?
}
Which::TinyLlama1_1BChat => vec![api.get("model.safetensors")?],