diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-09-26 21:00:18 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-09-26 21:00:18 +0200 |
commit | ad8a4c5e5a2f177e3b90ac812d66d6a5ed1d69dc (patch) | |
tree | d6514faca57dd0204170e04d6d6a94ca295fe278 /candle-examples/examples | |
parent | c3c392f45c14f60eb4fb8397cc5c1d3891c9656d (diff) | |
download | candle-ad8a4c5e5a2f177e3b90ac812d66d6a5ed1d69dc.tar.gz candle-ad8a4c5e5a2f177e3b90ac812d66d6a5ed1d69dc.tar.bz2 candle-ad8a4c5e5a2f177e3b90ac812d66d6a5ed1d69dc.zip |
Add some llama-3.2 examples. (#2508)
* Add some llama-3.2 examples.
* Support tie-word-embeddings for llama.
Diffstat (limited to 'candle-examples/examples')
-rw-r--r-- | candle-examples/examples/llama/main.rs | 14 |
1 files changed, 13 insertions, 1 deletions
diff --git a/candle-examples/examples/llama/main.rs b/candle-examples/examples/llama/main.rs index 93f1e508..7a555b00 100644 --- a/candle-examples/examples/llama/main.rs +++ b/candle-examples/examples/llama/main.rs @@ -35,6 +35,10 @@ enum Which { V31, V3Instruct, V31Instruct, + V32_1b, + V32_1bInstruct, + V32_3b, + V32_3bInstruct, #[value(name = "solar-10.7b")] Solar10_7B, #[value(name = "tiny-llama-1.1b-chat")] @@ -137,6 +141,10 @@ fn main() -> Result<()> { Which::V3Instruct => "meta-llama/Meta-Llama-3-8B-Instruct".to_string(), Which::V31 => "meta-llama/Meta-Llama-3.1-8B".to_string(), Which::V31Instruct => "meta-llama/Meta-Llama-3.1-8B-Instruct".to_string(), + Which::V32_1b => "meta-llama/Llama-3.2-1B".to_string(), + Which::V32_1bInstruct => "meta-llama/Llama-3.2-1B-Instruct".to_string(), + Which::V32_3b => "meta-llama/Llama-3.2-3B".to_string(), + Which::V32_3bInstruct => "meta-llama/Llama-3.2-3B-Instruct".to_string(), Which::Solar10_7B => "upstage/SOLAR-10.7B-v1.0".to_string(), Which::TinyLlama1_1BChat => "TinyLlama/TinyLlama-1.1B-Chat-v1.0".to_string(), }); @@ -156,10 +164,14 @@ fn main() -> Result<()> { | Which::V3Instruct | Which::V31 | Which::V31Instruct + | Which::V32_3b + | Which::V32_3bInstruct | Which::Solar10_7B => { candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")? } - Which::TinyLlama1_1BChat => vec![api.get("model.safetensors")?], + Which::V32_1b | Which::V32_1bInstruct | Which::TinyLlama1_1BChat => { + vec![api.get("model.safetensors")?] + } }; let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?; |