diff options
Diffstat (limited to 'candle-examples/examples/llama/main.rs')
-rw-r--r-- | candle-examples/examples/llama/main.rs | 30 |
1 files changed, 24 insertions, 6 deletions
diff --git a/candle-examples/examples/llama/main.rs b/candle-examples/examples/llama/main.rs
index fa7ce81b..93f1e508 100644
--- a/candle-examples/examples/llama/main.rs
+++ b/candle-examples/examples/llama/main.rs
@@ -32,7 +32,9 @@ enum Which {
     V1,
     V2,
     V3,
+    V31,
     V3Instruct,
+    V31Instruct,
     #[value(name = "solar-10.7b")]
     Solar10_7B,
     #[value(name = "tiny-llama-1.1b-chat")]
@@ -133,6 +135,8 @@ fn main() -> Result<()> {
             Which::V2 => "meta-llama/Llama-2-7b-hf".to_string(),
             Which::V3 => "meta-llama/Meta-Llama-3-8B".to_string(),
             Which::V3Instruct => "meta-llama/Meta-Llama-3-8B-Instruct".to_string(),
+            Which::V31 => "meta-llama/Meta-Llama-3.1-8B".to_string(),
+            Which::V31Instruct => "meta-llama/Meta-Llama-3.1-8B-Instruct".to_string(),
             Which::Solar10_7B => "upstage/SOLAR-10.7B-v1.0".to_string(),
             Which::TinyLlama1_1BChat => "TinyLlama/TinyLlama-1.1B-Chat-v1.0".to_string(),
         });
@@ -146,7 +150,13 @@ fn main() -> Result<()> {
         let config = config.into_config(args.use_flash_attn);
 
         let filenames = match args.which {
-            Which::V1 | Which::V2 | Which::V3 | Which::V3Instruct | Which::Solar10_7B => {
+            Which::V1
+            | Which::V2
+            | Which::V3
+            | Which::V3Instruct
+            | Which::V31
+            | Which::V31Instruct
+            | Which::Solar10_7B => {
                 candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?
             }
             Which::TinyLlama1_1BChat => vec![api.get("model.safetensors")?],
@@ -157,9 +167,11 @@ fn main() -> Result<()> {
         (Llama::load(vb, &config)?, tokenizer_filename, cache, config)
     };
     let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
-    let eos_token_id = config
-        .eos_token_id
-        .or_else(|| tokenizer.token_to_id(EOS_TOKEN));
+    let eos_token_id = config.eos_token_id.or_else(|| {
+        tokenizer
+            .token_to_id(EOS_TOKEN)
+            .map(model::LlamaEosToks::Single)
+    });
     let prompt = args.prompt.as_ref().map_or(DEFAULT_PROMPT, |p| p.as_str());
     let mut tokens = tokenizer
         .encode(prompt, true)
@@ -217,8 +229,14 @@ fn main() -> Result<()> {
         token_generated += 1;
         tokens.push(next_token);
 
-        if Some(next_token) == eos_token_id {
-            break;
+        match eos_token_id {
+            Some(model::LlamaEosToks::Single(eos_tok_id)) if next_token == eos_tok_id => {
+                break;
+            }
+            Some(model::LlamaEosToks::Multiple(ref eos_ids)) if eos_ids.contains(&next_token) => {
+                break;
+            }
+            _ => (),
         }
         if let Some(t) = tokenizer.next_token(next_token)? {
             print!("{t}");