summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-examples/examples/quantized/main.rs42
1 files changed, 36 insertions, 6 deletions
diff --git a/candle-examples/examples/quantized/main.rs b/candle-examples/examples/quantized/main.rs
index ab8a56ba..18db3f9a 100644
--- a/candle-examples/examples/quantized/main.rs
+++ b/candle-examples/examples/quantized/main.rs
@@ -55,6 +55,8 @@ enum Which {
Zephyr7bBeta,
#[value(name = "7b-open-chat-3.5")]
OpenChat35,
+ #[value(name = "7b-starling-a")]
+ Starling7bAlpha,
}
impl Which {
@@ -70,8 +72,9 @@ impl Which {
| Self::L13bCode
| Self::L34bCode => false,
// Zephyr and OpenChat are fine tuned versions of mistral and should be treated in the
- // same way.
+ // same way. Starling is a fine tuned version of OpenChat.
Self::OpenChat35
+ | Self::Starling7bAlpha
| Self::Zephyr7bAlpha
| Self::Zephyr7bBeta
| Self::Mistral7b
@@ -92,7 +95,8 @@ impl Which {
| Self::L34bCode
| Self::Mistral7b
| Self::Mistral7bInstruct
- | Self::OpenChat35 => false,
+ | Self::OpenChat35
+ | Self::Starling7bAlpha => false,
Self::Zephyr7bAlpha | Self::Zephyr7bBeta => true,
}
}
@@ -112,7 +116,26 @@ impl Which {
| Which::Mistral7bInstruct
| Which::Zephyr7bAlpha
| Which::Zephyr7bBeta => false,
- Which::OpenChat35 => true,
+ Which::OpenChat35 | Self::Starling7bAlpha => true,
+ }
+ }
+ fn is_starling(&self) -> bool {
+ match self {
+ Which::L7b
+ | Which::L13b
+ | Which::L70b
+ | Which::L7bChat
+ | Which::L13bChat
+ | Which::L70bChat
+ | Which::L7bCode
+ | Which::L13bCode
+ | Which::L34bCode
+ | Which::Mistral7b
+ | Which::Mistral7bInstruct
+ | Which::Zephyr7bAlpha
+ | Which::Zephyr7bBeta
+ | Which::OpenChat35 => false,
+ Which::Starling7bAlpha => true,
}
}
}
@@ -181,7 +204,9 @@ impl Args {
Some(config) => std::path::PathBuf::from(config),
None => {
let api = hf_hub::api::sync::Api::new()?;
- let repo = if self.which.is_open_chat() {
+ let repo = if self.which.is_starling() {
+ "berkeley-nest/Starling-LM-7B-alpha"
+ } else if self.which.is_open_chat() {
"openchat/openchat_3.5"
} else if self.which.is_mistral() {
"mistralai/Mistral-7B-v0.1"
@@ -234,6 +259,10 @@ impl Args {
("TheBloke/zephyr-7B-beta-GGUF", "zephyr-7b-beta.Q4_K_M.gguf")
}
Which::OpenChat35 => ("TheBloke/openchat_3.5-GGUF", "openchat_3.5.Q4_K_M.gguf"),
+ Which::Starling7bAlpha => (
+ "TheBloke/Starling-LM-7B-alpha-GGUF",
+ "starling-lm-7b-alpha.Q4_K_M.gguf",
+ ),
};
let api = hf_hub::api::sync::Api::new()?;
let api = api.model(repo.to_string());
@@ -336,7 +365,8 @@ fn main() -> anyhow::Result<()> {
| Which::Zephyr7bBeta
| Which::L70b
| Which::L70bChat
- | Which::OpenChat35 => 8,
+ | Which::OpenChat35
+ | Which::Starling7bAlpha => 8,
};
ModelWeights::from_ggml(model, args.gqa.unwrap_or(default_gqa))?
}
@@ -369,7 +399,7 @@ fn main() -> anyhow::Result<()> {
}
}
if args.which.is_open_chat() {
- format!("User: {prompt}<|end_of_turn|>Assistant: ")
+ format!("GPT4 Correct User: {prompt}<|end_of_turn|>GPT4 Correct Assistant:")
} else if args.which.is_zephyr() {
if prompt_index == 0 || is_interactive {
format!("<|system|>\n</s>\n<|user|>\n{prompt}</s>\n<|assistant|>",)