summaryrefslogtreecommitdiff
path: root/candle-examples/examples
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2024-09-26 21:00:18 +0200
committerGitHub <noreply@github.com>2024-09-26 21:00:18 +0200
commitad8a4c5e5a2f177e3b90ac812d66d6a5ed1d69dc (patch)
treed6514faca57dd0204170e04d6d6a94ca295fe278 /candle-examples/examples
parentc3c392f45c14f60eb4fb8397cc5c1d3891c9656d (diff)
downloadcandle-ad8a4c5e5a2f177e3b90ac812d66d6a5ed1d69dc.tar.gz
candle-ad8a4c5e5a2f177e3b90ac812d66d6a5ed1d69dc.tar.bz2
candle-ad8a4c5e5a2f177e3b90ac812d66d6a5ed1d69dc.zip
Add some llama-3.2 examples. (#2508)
* Add some llama-3.2 examples. * Support tie-word-embeddings for llama.
Diffstat (limited to 'candle-examples/examples')
-rw-r--r--candle-examples/examples/llama/main.rs14
1 file changed, 13 insertions, 1 deletion
diff --git a/candle-examples/examples/llama/main.rs b/candle-examples/examples/llama/main.rs
index 93f1e508..7a555b00 100644
--- a/candle-examples/examples/llama/main.rs
+++ b/candle-examples/examples/llama/main.rs
@@ -35,6 +35,10 @@ enum Which {
V31,
V3Instruct,
V31Instruct,
+ V32_1b,
+ V32_1bInstruct,
+ V32_3b,
+ V32_3bInstruct,
#[value(name = "solar-10.7b")]
Solar10_7B,
#[value(name = "tiny-llama-1.1b-chat")]
@@ -137,6 +141,10 @@ fn main() -> Result<()> {
Which::V3Instruct => "meta-llama/Meta-Llama-3-8B-Instruct".to_string(),
Which::V31 => "meta-llama/Meta-Llama-3.1-8B".to_string(),
Which::V31Instruct => "meta-llama/Meta-Llama-3.1-8B-Instruct".to_string(),
+ Which::V32_1b => "meta-llama/Llama-3.2-1B".to_string(),
+ Which::V32_1bInstruct => "meta-llama/Llama-3.2-1B-Instruct".to_string(),
+ Which::V32_3b => "meta-llama/Llama-3.2-3B".to_string(),
+ Which::V32_3bInstruct => "meta-llama/Llama-3.2-3B-Instruct".to_string(),
Which::Solar10_7B => "upstage/SOLAR-10.7B-v1.0".to_string(),
Which::TinyLlama1_1BChat => "TinyLlama/TinyLlama-1.1B-Chat-v1.0".to_string(),
});
@@ -156,10 +164,14 @@ fn main() -> Result<()> {
| Which::V3Instruct
| Which::V31
| Which::V31Instruct
+ | Which::V32_3b
+ | Which::V32_3bInstruct
| Which::Solar10_7B => {
candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?
}
- Which::TinyLlama1_1BChat => vec![api.get("model.safetensors")?],
+ Which::V32_1b | Which::V32_1bInstruct | Which::TinyLlama1_1BChat => {
+ vec![api.get("model.safetensors")?]
+ }
};
let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;