diff options
Diffstat (limited to 'candle-examples/examples/quantized/main.rs')
-rw-r--r-- | candle-examples/examples/quantized/main.rs | 9 |
1 files changed, 6 insertions, 3 deletions
diff --git a/candle-examples/examples/quantized/main.rs b/candle-examples/examples/quantized/main.rs index 2e9e7f07..15926e0a 100644 --- a/candle-examples/examples/quantized/main.rs +++ b/candle-examples/examples/quantized/main.rs @@ -24,7 +24,7 @@ enum Prompt { One(String), } -#[derive(Clone, Debug, Copy, ValueEnum)] +#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)] enum Which { #[value(name = "7b")] L7b, @@ -64,7 +64,8 @@ impl Which { | Self::L7bCode | Self::L13bCode | Self::L34bCode => false, - Self::Mistral7b | Self::Mistral7bInstruct | Self::Zephyr7b => true, + // Zephyr is a fine tuned version of mistral and should be treated in the same way. + Self::Zephyr7b | Self::Mistral7b | Self::Mistral7bInstruct => true, } } } @@ -335,7 +336,9 @@ fn main() -> anyhow::Result<()> { prompt.pop(); } } - if args.which.is_mistral() { + if args.which == Which::Zephyr7b { + format!("<|system|>\n</s>\n<|user|>\n{prompt}</s>\n<|assistant|>") + } else if args.which.is_mistral() { format!("[INST] {prompt} [/INST]") } else { prompt |