summaryrefslogtreecommitdiff
path: root/candle-examples/examples/quantized/main.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-examples/examples/quantized/main.rs')
-rw-r--r--candle-examples/examples/quantized/main.rs103
1 files changed, 85 insertions, 18 deletions
diff --git a/candle-examples/examples/quantized/main.rs b/candle-examples/examples/quantized/main.rs
index ab8a56ba..df758b4f 100644
--- a/candle-examples/examples/quantized/main.rs
+++ b/candle-examples/examples/quantized/main.rs
@@ -45,6 +45,10 @@ enum Which {
L13bCode,
#[value(name = "32b-code")]
L34bCode,
+ #[value(name = "7b-leo")]
+ Leo7b,
+ #[value(name = "13b-leo")]
+ Leo13b,
#[value(name = "7b-mistral")]
Mistral7b,
#[value(name = "7b-mistral-instruct")]
@@ -55,6 +59,12 @@ enum Which {
Zephyr7bBeta,
#[value(name = "7b-open-chat-3.5")]
OpenChat35,
+ #[value(name = "7b-starling-a")]
+ Starling7bAlpha,
+ #[value(name = "mixtral")]
+ Mixtral,
+ #[value(name = "mixtral-instruct")]
+ MixtralInstruct,
}
impl Which {
@@ -68,12 +78,17 @@ impl Which {
| Self::L70bChat
| Self::L7bCode
| Self::L13bCode
- | Self::L34bCode => false,
+ | Self::L34bCode
+ | Self::Leo7b
+ | Self::Leo13b => false,
// Zephyr and OpenChat are fine tuned versions of mistral and should be treated in the
- // same way.
+ // same way. Starling is a fine tuned version of OpenChat.
Self::OpenChat35
+ | Self::Starling7bAlpha
| Self::Zephyr7bAlpha
| Self::Zephyr7bBeta
+ | Self::Mixtral
+ | Self::MixtralInstruct
| Self::Mistral7b
| Self::Mistral7bInstruct => true,
}
@@ -90,15 +105,43 @@ impl Which {
| Self::L7bCode
| Self::L13bCode
| Self::L34bCode
+ | Self::Leo7b
+ | Self::Leo13b
+ | Self::Mixtral
+ | Self::MixtralInstruct
| Self::Mistral7b
| Self::Mistral7bInstruct
- | Self::OpenChat35 => false,
+ | Self::OpenChat35
+ | Self::Starling7bAlpha => false,
Self::Zephyr7bAlpha | Self::Zephyr7bBeta => true,
}
}
fn is_open_chat(&self) -> bool {
match self {
+ Self::L7b
+ | Self::L13b
+ | Self::L70b
+ | Self::L7bChat
+ | Self::L13bChat
+ | Self::L70bChat
+ | Self::L7bCode
+ | Self::L13bCode
+ | Self::L34bCode
+ | Self::Leo7b
+ | Self::Leo13b
+ | Self::Mixtral
+ | Self::MixtralInstruct
+ | Self::Mistral7b
+ | Self::Mistral7bInstruct
+ | Self::Zephyr7bAlpha
+ | Self::Zephyr7bBeta => false,
+ Self::OpenChat35 | Self::Starling7bAlpha => true,
+ }
+ }
+
+ fn tokenizer_repo(&self) -> &'static str {
+ match self {
Which::L7b
| Which::L13b
| Which::L70b
@@ -107,12 +150,17 @@ impl Which {
| Which::L70bChat
| Which::L7bCode
| Which::L13bCode
- | Which::L34bCode
- | Which::Mistral7b
+ | Which::L34bCode => "hf-internal-testing/llama-tokenizer",
+ Which::Leo7b => "LeoLM/leo-hessianai-7b",
+ Which::Leo13b => "LeoLM/leo-hessianai-13b",
+ Which::Mixtral => "mistralai/Mixtral-8x7B-v0.1",
+ Which::MixtralInstruct => "mistralai/Mixtral-8x7B-Instruct-v0.1",
+ Which::Mistral7b
| Which::Mistral7bInstruct
| Which::Zephyr7bAlpha
- | Which::Zephyr7bBeta => false,
- Which::OpenChat35 => true,
+ | Which::Zephyr7bBeta => "mistralai/Mistral-7B-v0.1",
+ Which::OpenChat35 => "openchat/openchat_3.5",
+ Which::Starling7bAlpha => "berkeley-nest/Starling-LM-7B-alpha",
}
}
}
@@ -181,13 +229,7 @@ impl Args {
Some(config) => std::path::PathBuf::from(config),
None => {
let api = hf_hub::api::sync::Api::new()?;
- let repo = if self.which.is_open_chat() {
- "openchat/openchat_3.5"
- } else if self.which.is_mistral() {
- "mistralai/Mistral-7B-v0.1"
- } else {
- "hf-internal-testing/llama-tokenizer"
- };
+ let repo = self.which.tokenizer_repo();
let api = api.model(repo.to_string());
api.get("tokenizer.json")?
}
@@ -218,6 +260,22 @@ impl Args {
Which::L7bCode => ("TheBloke/CodeLlama-7B-GGUF", "codellama-7b.Q8_0.gguf"),
Which::L13bCode => ("TheBloke/CodeLlama-13B-GGUF", "codellama-13b.Q8_0.gguf"),
Which::L34bCode => ("TheBloke/CodeLlama-34B-GGUF", "codellama-34b.Q8_0.gguf"),
+ Which::Leo7b => (
+ "TheBloke/leo-hessianai-7B-GGUF",
+ "leo-hessianai-7b.Q4_K_M.gguf",
+ ),
+ Which::Leo13b => (
+ "TheBloke/leo-hessianai-13B-GGUF",
+ "leo-hessianai-13b.Q4_K_M.gguf",
+ ),
+ Which::Mixtral => (
+ "TheBloke/Mixtral-8x7B-v0.1-GGUF",
+ "mixtral-8x7b-v0.1.Q4_K_M.gguf",
+ ),
+ Which::MixtralInstruct => (
+ "TheBloke/Mixtral-8x7B-Instruct-v0.1-GGUF",
+ "mixtral-8x7b-instruct-v0.1.Q4_K_M.gguf",
+ ),
Which::Mistral7b => (
"TheBloke/Mistral-7B-v0.1-GGUF",
"mistral-7b-v0.1.Q4_K_S.gguf",
@@ -234,6 +292,10 @@ impl Args {
("TheBloke/zephyr-7B-beta-GGUF", "zephyr-7b-beta.Q4_K_M.gguf")
}
Which::OpenChat35 => ("TheBloke/openchat_3.5-GGUF", "openchat_3.5.Q4_K_M.gguf"),
+ Which::Starling7bAlpha => (
+ "TheBloke/Starling-LM-7B-alpha-GGUF",
+ "starling-lm-7b-alpha.Q4_K_M.gguf",
+ ),
};
let api = hf_hub::api::sync::Api::new()?;
let api = api.model(repo.to_string());
@@ -329,14 +391,19 @@ fn main() -> anyhow::Result<()> {
| Which::L13bChat
| Which::L7bCode
| Which::L13bCode
- | Which::L34bCode => 1,
- Which::Mistral7b
+ | Which::L34bCode
+ | Which::Leo7b
+ | Which::Leo13b => 1,
+ Which::Mixtral
+ | Which::MixtralInstruct
+ | Which::Mistral7b
| Which::Mistral7bInstruct
| Which::Zephyr7bAlpha
| Which::Zephyr7bBeta
| Which::L70b
| Which::L70bChat
- | Which::OpenChat35 => 8,
+ | Which::OpenChat35
+ | Which::Starling7bAlpha => 8,
};
ModelWeights::from_ggml(model, args.gqa.unwrap_or(default_gqa))?
}
@@ -369,7 +436,7 @@ fn main() -> anyhow::Result<()> {
}
}
if args.which.is_open_chat() {
- format!("User: {prompt}<|end_of_turn|>Assistant: ")
+ format!("GPT4 Correct User: {prompt}<|end_of_turn|>GPT4 Correct Assistant:")
} else if args.which.is_zephyr() {
if prompt_index == 0 || is_interactive {
format!("<|system|>\n</s>\n<|user|>\n{prompt}</s>\n<|assistant|>",)