summaryrefslogtreecommitdiff
path: root/candle-examples
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-12-03 13:30:41 +0100
committerGitHub <noreply@github.com>2023-12-03 12:30:41 +0000
commit16161145ae54c8f9692c06ea9931fa5d1fac4873 (patch)
tree6e535f63cd2c62205d96b9c8d7e2359548ba35b3 /candle-examples
parent0738df5290e5c18d489557e9e02cde4acbbb2249 (diff)
downloadcandle-16161145ae54c8f9692c06ea9931fa5d1fac4873.tar.gz
candle-16161145ae54c8f9692c06ea9931fa5d1fac4873.tar.bz2
candle-16161145ae54c8f9692c06ea9931fa5d1fac4873.zip
Add the leo models to the quantized examples. (#1398)
Diffstat (limited to 'candle-examples')
-rw-r--r--candle-examples/examples/quantized/main.rs | 77
1 file changed, 46 insertions(+), 31 deletions(-)
diff --git a/candle-examples/examples/quantized/main.rs b/candle-examples/examples/quantized/main.rs
index 18db3f9a..b21d6751 100644
--- a/candle-examples/examples/quantized/main.rs
+++ b/candle-examples/examples/quantized/main.rs
@@ -45,6 +45,10 @@ enum Which {
L13bCode,
#[value(name = "32b-code")]
L34bCode,
+ #[value(name = "7b-leo")]
+ Leo7b,
+ #[value(name = "13b-leo")]
+ Leo13b,
#[value(name = "7b-mistral")]
Mistral7b,
#[value(name = "7b-mistral-instruct")]
@@ -70,7 +74,9 @@ impl Which {
| Self::L70bChat
| Self::L7bCode
| Self::L13bCode
- | Self::L34bCode => false,
+ | Self::L34bCode
+ | Self::Leo7b
+ | Self::Leo13b => false,
// Zephyr and OpenChat are fine tuned versions of mistral and should be treated in the
// same way. Starling is a fine tuned version of OpenChat.
Self::OpenChat35
@@ -93,6 +99,8 @@ impl Which {
| Self::L7bCode
| Self::L13bCode
| Self::L34bCode
+ | Self::Leo7b
+ | Self::Leo13b
| Self::Mistral7b
| Self::Mistral7bInstruct
| Self::OpenChat35
@@ -103,23 +111,26 @@ impl Which {
fn is_open_chat(&self) -> bool {
match self {
- Which::L7b
- | Which::L13b
- | Which::L70b
- | Which::L7bChat
- | Which::L13bChat
- | Which::L70bChat
- | Which::L7bCode
- | Which::L13bCode
- | Which::L34bCode
- | Which::Mistral7b
- | Which::Mistral7bInstruct
- | Which::Zephyr7bAlpha
- | Which::Zephyr7bBeta => false,
- Which::OpenChat35 | Self::Starling7bAlpha => true,
+ Self::L7b
+ | Self::L13b
+ | Self::L70b
+ | Self::L7bChat
+ | Self::L13bChat
+ | Self::L70bChat
+ | Self::L7bCode
+ | Self::L13bCode
+ | Self::L34bCode
+ | Self::Leo7b
+ | Self::Leo13b
+ | Self::Mistral7b
+ | Self::Mistral7bInstruct
+ | Self::Zephyr7bAlpha
+ | Self::Zephyr7bBeta => false,
+ Self::OpenChat35 | Self::Starling7bAlpha => true,
}
}
- fn is_starling(&self) -> bool {
+
+ fn tokenizer_repo(&self) -> &'static str {
match self {
Which::L7b
| Which::L13b
@@ -129,13 +140,15 @@ impl Which {
| Which::L70bChat
| Which::L7bCode
| Which::L13bCode
- | Which::L34bCode
- | Which::Mistral7b
+ | Which::L34bCode => "hf-internal-testing/llama-tokenizer",
+ Which::Leo7b => "LeoLM/leo-hessianai-7b",
+ Which::Leo13b => "LeoLM/leo-hessianai-13b",
+ Which::Mistral7b
| Which::Mistral7bInstruct
| Which::Zephyr7bAlpha
- | Which::Zephyr7bBeta
- | Which::OpenChat35 => false,
- Which::Starling7bAlpha => true,
+ | Which::Zephyr7bBeta => "mistralai/Mistral-7B-v0.1",
+ Which::OpenChat35 => "openchat/openchat_3.5",
+ Which::Starling7bAlpha => "berkeley-nest/Starling-LM-7B-alpha",
}
}
}
@@ -204,15 +217,7 @@ impl Args {
Some(config) => std::path::PathBuf::from(config),
None => {
let api = hf_hub::api::sync::Api::new()?;
- let repo = if self.which.is_starling() {
- "berkeley-nest/Starling-LM-7B-alpha"
- } else if self.which.is_open_chat() {
- "openchat/openchat_3.5"
- } else if self.which.is_mistral() {
- "mistralai/Mistral-7B-v0.1"
- } else {
- "hf-internal-testing/llama-tokenizer"
- };
+ let repo = self.which.tokenizer_repo();
let api = api.model(repo.to_string());
api.get("tokenizer.json")?
}
@@ -243,6 +248,14 @@ impl Args {
Which::L7bCode => ("TheBloke/CodeLlama-7B-GGUF", "codellama-7b.Q8_0.gguf"),
Which::L13bCode => ("TheBloke/CodeLlama-13B-GGUF", "codellama-13b.Q8_0.gguf"),
Which::L34bCode => ("TheBloke/CodeLlama-34B-GGUF", "codellama-34b.Q8_0.gguf"),
+ Which::Leo7b => (
+ "TheBloke/leo-hessianai-7B-GGUF",
+ "leo-hessianai-7b.Q4_K_M.gguf",
+ ),
+ Which::Leo13b => (
+ "TheBloke/leo-hessianai-13B-GGUF",
+ "leo-hessianai-13b.Q4_K_M.gguf",
+ ),
Which::Mistral7b => (
"TheBloke/Mistral-7B-v0.1-GGUF",
"mistral-7b-v0.1.Q4_K_S.gguf",
@@ -358,7 +371,9 @@ fn main() -> anyhow::Result<()> {
| Which::L13bChat
| Which::L7bCode
| Which::L13bCode
- | Which::L34bCode => 1,
+ | Which::L34bCode
+ | Which::Leo7b
+ | Which::Leo13b => 1,
Which::Mistral7b
| Which::Mistral7bInstruct
| Which::Zephyr7bAlpha