summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-examples/examples/quantized/main.rs9
1 files changed, 7 insertions, 2 deletions
diff --git a/candle-examples/examples/quantized/main.rs b/candle-examples/examples/quantized/main.rs
index 12b4b059..068ae12d 100644
--- a/candle-examples/examples/quantized/main.rs
+++ b/candle-examples/examples/quantized/main.rs
@@ -325,10 +325,11 @@ fn main() -> anyhow::Result<()> {
};
let mut pre_prompt_tokens = vec![];
- loop {
+ for prompt_index in 0.. {
let prompt_str = match &prompt {
Prompt::One(prompt) => prompt.clone(),
Prompt::Interactive | Prompt::Chat => {
+ let is_interactive = matches!(prompt, Prompt::Interactive);
print!("> ");
std::io::stdout().flush()?;
let mut prompt = String::new();
@@ -340,7 +341,11 @@ fn main() -> anyhow::Result<()> {
}
}
if args.which.is_zephyr() {
- format!("<|system|>\n</s>\n<|user|>\n{prompt}</s>\n<|assistant|>")
+ if prompt_index == 0 || is_interactive {
+ format!("<|system|>\n</s>\n<|user|>\n{prompt}</s>\n<|assistant|>",)
+ } else {
+ format!("<|user|>\n{prompt}</s>\n<|assistant|>")
+ }
} else if args.which.is_mistral() {
format!("[INST] {prompt} [/INST]")
} else {