author     Laurent Mazare <laurent.mazare@gmail.com>  2023-10-01 06:44:30 +0100
committer  GitHub <noreply@github.com>                2023-10-01 06:44:30 +0100
commit     f6054e9d60ef15add8a9a20b0aae8db630383d8f (patch)
tree       6376332c54e0b2a72c6b0eaecd048ecddb32cbe9
parent     328167ec04bec4536b4ab5581685ebdf918211ee (diff)
Fix the prompt for Mistral when using instruct/interactive mode. (#1013)
-rw-r--r--  candle-examples/examples/quantized/main.rs | 43
1 file changed, 31 insertions(+), 12 deletions(-)
diff --git a/candle-examples/examples/quantized/main.rs b/candle-examples/examples/quantized/main.rs
index 3e663851..bb613dc7 100644
--- a/candle-examples/examples/quantized/main.rs
+++ b/candle-examples/examples/quantized/main.rs
@@ -50,6 +50,23 @@ enum Which {
     Mistral7bInstruct,
 }
 
+impl Which {
+    fn is_mistral(&self) -> bool {
+        match self {
+            Self::L7b
+            | Self::L13b
+            | Self::L70b
+            | Self::L7bChat
+            | Self::L13bChat
+            | Self::L70bChat
+            | Self::L7bCode
+            | Self::L13bCode
+            | Self::L34bCode => false,
+            Self::Mistral7b | Self::Mistral7bInstruct => true,
+        }
+    }
+}
+
 #[derive(Parser, Debug)]
 #[command(author, version, about, long_about = None)]
 struct Args {
@@ -114,17 +131,10 @@ impl Args {
             Some(config) => std::path::PathBuf::from(config),
             None => {
                 let api = hf_hub::api::sync::Api::new()?;
-                let repo = match self.which {
-                    Which::L7b
-                    | Which::L13b
-                    | Which::L70b
-                    | Which::L7bCode
-                    | Which::L13bCode
-                    | Which::L34bCode
-                    | Which::L7bChat
-                    | Which::L13bChat
-                    | Which::L70bChat => "hf-internal-testing/llama-tokenizer",
-                    Which::Mistral7b | Which::Mistral7bInstruct => "mistralai/Mistral-7B-v0.1",
+                let repo = if self.which.is_mistral() {
+                    "mistralai/Mistral-7B-v0.1"
+                } else {
+                    "hf-internal-testing/llama-tokenizer"
                 };
                 let api = api.model(repo.to_string());
                 api.get("tokenizer.json")?
@@ -315,7 +325,11 @@ fn main() -> anyhow::Result<()> {
                         prompt.pop();
                     }
                 }
-                prompt
+                if args.which.is_mistral() {
+                    format!("[INST] {prompt} [/INST]")
+                } else {
+                    prompt
+                }
             }
         };
         print!("{}", &prompt_str);
@@ -351,6 +365,8 @@ fn main() -> anyhow::Result<()> {
         all_tokens.push(next_token);
         print_token(next_token, &tokenizer);
 
+        let eos_token = *tokenizer.get_vocab(true).get("</s>").unwrap();
+
         let start_post_prompt = std::time::Instant::now();
         for index in 0..to_sample {
             let input = Tensor::new(&[next_token], &Device::Cpu)?.unsqueeze(0)?;
@@ -369,6 +385,9 @@ fn main() -> anyhow::Result<()> {
             next_token = logits_processor.sample(&logits)?;
             all_tokens.push(next_token);
             print_token(next_token, &tokenizer);
+            if next_token == eos_token {
+                break;
+            };
         }
         let dt = start_post_prompt.elapsed();
         println!(
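For context, the two behaviors this diff introduces can be sketched in isolation: Mistral's instruct checkpoints expect prompts wrapped in [INST] ... [/INST] markers, and generation should stop as soon as the </s> end-of-sequence token is sampled. The standalone Rust sketch below is illustrative only and not part of the commit; the vocabulary map and the hard-coded token sequence are toy stand-ins for the real tokenizer and model output.

use std::collections::HashMap;

// Wrap a raw prompt the way the example now does for Mistral instruct models.
fn format_mistral_prompt(prompt: &str) -> String {
    format!("[INST] {prompt} [/INST]")
}

fn main() {
    let prompt = format_mistral_prompt("Why is the sky blue?");
    assert_eq!(prompt, "[INST] Why is the sky blue? [/INST]");

    // Toy vocabulary; in the real example the id comes from
    // tokenizer.get_vocab(true), as in the diff above.
    let vocab: HashMap<&str, u32> = HashMap::from([("</s>", 2)]);
    let eos_token = *vocab.get("</s>").unwrap();

    // Stand-in for the sampling loop: push each sampled token, but stop as
    // soon as the end-of-sequence token appears.
    let sampled = [5u32, 17, 9, 2, 42];
    let mut all_tokens = Vec::new();
    for &next_token in &sampled {
        all_tokens.push(next_token);
        if next_token == eos_token {
            break;
        }
    }
    assert_eq!(all_tokens, vec![5, 17, 9, 2]); // nothing generated after </s>
}

Looking the EOS id up once before the post-prompt loop, as the commit does, also keeps the vocabulary lookup out of the per-token sampling path.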