author    Laurent Mazare <laurent.mazare@gmail.com> 2023-10-01 06:44:30 +0100
committer GitHub <noreply@github.com> 2023-10-01 06:44:30 +0100
commit    f6054e9d60ef15add8a9a20b0aae8db630383d8f
tree      6376332c54e0b2a72c6b0eaecd048ecddb32cbe9
parent    328167ec04bec4536b4ab5581685ebdf918211ee
Fix the prompt for mistral when using instruct/interactive mode. (#1013)
-rw-r--r-- candle-examples/examples/quantized/main.rs | 43
1 file changed, 31 insertions(+), 12 deletions(-)
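For context: the Mistral 7B instruct checkpoints are fine-tuned on prompts wrapped in [INST] ... [/INST] tags and emit a </s> end-of-sequence token when a reply is complete, neither of which the LLaMA-style variants need. The sketch below shows the two behaviours this patch adds, stripped of the model plumbing; the helper names (format_prompt, take_until_eos) are illustrative and not part of the example binary.

    // Sketch of the two behaviours added by this patch.
    fn format_prompt(prompt: &str, is_mistral: bool) -> String {
        if is_mistral {
            // Mistral instruct fine-tunes expect [INST] ... [/INST] turns.
            format!("[INST] {prompt} [/INST]")
        } else {
            prompt.to_string()
        }
    }

    // Stand-in for the sampling loop: `sampled` mimics a stream of token
    // ids and `eos_token` the id looked up from the tokenizer vocabulary
    // (</s> is id 2 in the LLaMA/Mistral vocabulary). The eos token itself
    // is kept, matching the patch, which pushes and prints the token
    // before breaking.
    fn take_until_eos(sampled: &[u32], eos_token: u32) -> Vec<u32> {
        let mut out = Vec::new();
        for &tok in sampled {
            out.push(tok);
            if tok == eos_token {
                break;
            }
        }
        out
    }

    fn main() {
        assert_eq!(format_prompt("2 + 2 = ?", true), "[INST] 2 + 2 = ? [/INST]");
        assert_eq!(format_prompt("2 + 2 = ?", false), "2 + 2 = ?");
        assert_eq!(take_until_eos(&[5, 9, 2, 7, 8], 2), vec![5, 9, 2]);
    }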
diff --git a/candle-examples/examples/quantized/main.rs b/candle-examples/examples/quantized/main.rs
index 3e663851..bb613dc7 100644
--- a/candle-examples/examples/quantized/main.rs
+++ b/candle-examples/examples/quantized/main.rs
@@ -50,6 +50,23 @@ enum Which {
     Mistral7bInstruct,
 }
 
+impl Which {
+    fn is_mistral(&self) -> bool {
+        match self {
+            Self::L7b
+            | Self::L13b
+            | Self::L70b
+            | Self::L7bChat
+            | Self::L13bChat
+            | Self::L70bChat
+            | Self::L7bCode
+            | Self::L13bCode
+            | Self::L34bCode => false,
+            Self::Mistral7b | Self::Mistral7bInstruct => true,
+        }
+    }
+}
+
 #[derive(Parser, Debug)]
 #[command(author, version, about, long_about = None)]
 struct Args {
@@ -114,17 +131,10 @@ impl Args {
             Some(config) => std::path::PathBuf::from(config),
             None => {
                 let api = hf_hub::api::sync::Api::new()?;
-                let repo = match self.which {
-                    Which::L7b
-                    | Which::L13b
-                    | Which::L70b
-                    | Which::L7bCode
-                    | Which::L13bCode
-                    | Which::L34bCode
-                    | Which::L7bChat
-                    | Which::L13bChat
-                    | Which::L70bChat => "hf-internal-testing/llama-tokenizer",
-                    Which::Mistral7b | Which::Mistral7bInstruct => "mistralai/Mistral-7B-v0.1",
+                let repo = if self.which.is_mistral() {
+                    "mistralai/Mistral-7B-v0.1"
+                } else {
+                    "hf-internal-testing/llama-tokenizer"
                 };
                 let api = api.model(repo.to_string());
                 api.get("tokenizer.json")?
@@ -315,7 +325,11 @@ fn main() -> anyhow::Result<()> {
                         prompt.pop();
                     }
                 }
-                prompt
+                if args.which.is_mistral() {
+                    format!("[INST] {prompt} [/INST]")
+                } else {
+                    prompt
+                }
             }
         };
         print!("{}", &prompt_str);
@@ -351,6 +365,8 @@
         all_tokens.push(next_token);
         print_token(next_token, &tokenizer);
 
+        let eos_token = *tokenizer.get_vocab(true).get("</s>").unwrap();
+
         let start_post_prompt = std::time::Instant::now();
         for index in 0..to_sample {
             let input = Tensor::new(&[next_token], &Device::Cpu)?.unsqueeze(0)?;
@@ -369,6 +385,9 @@
             next_token = logits_processor.sample(&logits)?;
             all_tokens.push(next_token);
             print_token(next_token, &tokenizer);
+            if next_token == eos_token {
+                break;
+            };
         }
         let dt = start_post_prompt.elapsed();
         println!(