Diffstat (limited to 'candle-examples/examples/llama/main.rs')
-rw-r--r--  candle-examples/examples/llama/main.rs | 30 ++++++++++++++++++++++++------
1 file changed, 24 insertions(+), 6 deletions(-)
diff --git a/candle-examples/examples/llama/main.rs b/candle-examples/examples/llama/main.rs
index fa7ce81b..93f1e508 100644
--- a/candle-examples/examples/llama/main.rs
+++ b/candle-examples/examples/llama/main.rs
@@ -32,7 +32,9 @@ enum Which {
     V1,
     V2,
     V3,
+    V31,
     V3Instruct,
+    V31Instruct,
     #[value(name = "solar-10.7b")]
     Solar10_7B,
     #[value(name = "tiny-llama-1.1b-chat")]
@@ -133,6 +135,8 @@ fn main() -> Result<()> {
             Which::V2 => "meta-llama/Llama-2-7b-hf".to_string(),
             Which::V3 => "meta-llama/Meta-Llama-3-8B".to_string(),
             Which::V3Instruct => "meta-llama/Meta-Llama-3-8B-Instruct".to_string(),
+            Which::V31 => "meta-llama/Meta-Llama-3.1-8B".to_string(),
+            Which::V31Instruct => "meta-llama/Meta-Llama-3.1-8B-Instruct".to_string(),
             Which::Solar10_7B => "upstage/SOLAR-10.7B-v1.0".to_string(),
             Which::TinyLlama1_1BChat => "TinyLlama/TinyLlama-1.1B-Chat-v1.0".to_string(),
         });
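
Each string above is a Hugging Face Hub repository id. A hedged sketch of how such an id resolves to files on disk through the hf-hub crate's blocking API (the tokenizer.json file name is an assumption based on the usual Llama repo layout, not something this hunk touches):

    use hf_hub::api::sync::Api;

    fn fetch_tokenizer(model_id: &str) -> anyhow::Result<std::path::PathBuf> {
        let api = Api::new()?;
        // An ApiRepo pointing at e.g. "meta-llama/Meta-Llama-3.1-8B".
        let repo = api.model(model_id.to_string());
        // Downloads the file on first use, then serves it from the local cache.
        Ok(repo.get("tokenizer.json")?)
    }

Note that the meta-llama repositories are gated, so a valid Hugging Face token has to be available for the download to succeed.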
@@ -146,7 +150,13 @@ fn main() -> Result<()> {
         let config = config.into_config(args.use_flash_attn);
 
         let filenames = match args.which {
-            Which::V1 | Which::V2 | Which::V3 | Which::V3Instruct | Which::Solar10_7B => {
+            Which::V1
+            | Which::V2
+            | Which::V3
+            | Which::V3Instruct
+            | Which::V31
+            | Which::V31Instruct
+            | Which::Solar10_7B => {
                 candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?
             }
             Which::TinyLlama1_1BChat => vec![api.get("model.safetensors")?],
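
All of the multi-shard variants now funnel through model.safetensors.index.json. A hedged sketch of what a helper like candle_examples::hub_load_safetensors has to do: fetch the index, collect the distinct shard names listed in its weight_map, and download each shard. Treat this as an approximation of the idea, not the helper's actual source:

    use std::collections::HashSet;
    use std::path::PathBuf;

    use hf_hub::api::sync::ApiRepo;

    fn load_sharded(repo: &ApiRepo, index_file: &str) -> anyhow::Result<Vec<PathBuf>> {
        let index_path = repo.get(index_file)?;
        let index: serde_json::Value =
            serde_json::from_reader(std::fs::File::open(index_path)?)?;
        // weight_map maps each tensor name to the shard file holding it.
        let weight_map = index["weight_map"]
            .as_object()
            .ok_or_else(|| anyhow::anyhow!("no weight_map in {index_file}"))?;
        // Many tensors share a shard, so deduplicate before downloading.
        let shards: HashSet<&str> = weight_map.values().filter_map(|v| v.as_str()).collect();
        shards.into_iter().map(|s| Ok(repo.get(s)?)).collect()
    }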
@@ -157,9 +167,11 @@ fn main() -> Result<()> {
         (Llama::load(vb, &config)?, tokenizer_filename, cache, config)
     };
     let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
-    let eos_token_id = config
-        .eos_token_id
-        .or_else(|| tokenizer.token_to_id(EOS_TOKEN));
+    let eos_token_id = config.eos_token_id.or_else(|| {
+        tokenizer
+            .token_to_id(EOS_TOKEN)
+            .map(model::LlamaEosToks::Single)
+    });
     let prompt = args.prompt.as_ref().map_or(DEFAULT_PROMPT, |p| p.as_str());
     let mut tokens = tokenizer
         .encode(prompt, true)
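
The fallback changes shape because config.eos_token_id is no longer a bare u32: Llama 3.1 configs can list several terminator ids, which candle-transformers models as LlamaEosToks. Judging from the variants used in this diff, the type is essentially the following (reproduced here for reading convenience; the real definition lives in candle-transformers' llama module):

    pub enum LlamaEosToks {
        Single(u32),
        Multiple(Vec<u32>),
    }

A tokenizer lookup can only ever produce one id, hence the .map(model::LlamaEosToks::Single) when falling back to the literal EOS_TOKEN.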
@@ -217,8 +229,14 @@ fn main() -> Result<()> {
         token_generated += 1;
         tokens.push(next_token);
 
-        if Some(next_token) == eos_token_id {
-            break;
+        match eos_token_id {
+            Some(model::LlamaEosToks::Single(eos_tok_id)) if next_token == eos_tok_id => {
+                break;
+            }
+            Some(model::LlamaEosToks::Multiple(ref eos_ids)) if eos_ids.contains(&next_token) => {
+                break;
+            }
+            _ => (),
         }
         if let Some(t) = tokenizer.next_token(next_token)? {
             print!("{t}");