author     Laurent Mazare <laurent.mazare@gmail.com>  2023-12-03 13:30:41 +0100
committer  GitHub <noreply@github.com>                2023-12-03 12:30:41 +0000
commit     16161145ae54c8f9692c06ea9931fa5d1fac4873 (patch)
tree       6e535f63cd2c62205d96b9c8d7e2359548ba35b3 /candle-examples
parent     0738df5290e5c18d489557e9e02cde4acbbb2249 (diff)
Add the leo models to the quantized examples. (#1398)
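This wires the LeoLM leo-hessianai 7B and 13B models into the quantized example: two new `Which` variants (`7b-leo`, `13b-leo`), their GGUF weight files from TheBloke, and their tokenizer repos. Assuming the example's existing clap flags, the new variants should be selectable with something like `cargo run --example quantized --release -- --which 7b-leo --prompt "..."` (the exact invocation is not part of this diff).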
Diffstat (limited to 'candle-examples')
-rw-r--r--   candle-examples/examples/quantized/main.rs   77
1 file changed, 46 insertions(+), 31 deletions(-)
diff --git a/candle-examples/examples/quantized/main.rs b/candle-examples/examples/quantized/main.rs
index 18db3f9a..b21d6751 100644
--- a/candle-examples/examples/quantized/main.rs
+++ b/candle-examples/examples/quantized/main.rs
@@ -45,6 +45,10 @@ enum Which {
     L13bCode,
     #[value(name = "32b-code")]
     L34bCode,
+    #[value(name = "7b-leo")]
+    Leo7b,
+    #[value(name = "13b-leo")]
+    Leo13b,
     #[value(name = "7b-mistral")]
     Mistral7b,
     #[value(name = "7b-mistral-instruct")]
@@ -70,7 +74,9 @@ impl Which {
             | Self::L70bChat
             | Self::L7bCode
             | Self::L13bCode
-            | Self::L34bCode => false,
+            | Self::L34bCode
+            | Self::Leo7b
+            | Self::Leo13b => false,
             // Zephyr and OpenChat are fine tuned versions of mistral and should be treated in the
             // same way. Starling is a fine tuned version of OpenChat.
             Self::OpenChat35
@@ -93,6 +99,8 @@ impl Which {
             | Self::L7bCode
             | Self::L13bCode
             | Self::L34bCode
+            | Self::Leo7b
+            | Self::Leo13b
             | Self::Mistral7b
             | Self::Mistral7bInstruct
             | Self::OpenChat35
@@ -103,23 +111,26 @@ impl Which {
     fn is_open_chat(&self) -> bool {
         match self {
-            Which::L7b
-            | Which::L13b
-            | Which::L70b
-            | Which::L7bChat
-            | Which::L13bChat
-            | Which::L70bChat
-            | Which::L7bCode
-            | Which::L13bCode
-            | Which::L34bCode
-            | Which::Mistral7b
-            | Which::Mistral7bInstruct
-            | Which::Zephyr7bAlpha
-            | Which::Zephyr7bBeta => false,
-            Which::OpenChat35 | Self::Starling7bAlpha => true,
+            Self::L7b
+            | Self::L13b
+            | Self::L70b
+            | Self::L7bChat
+            | Self::L13bChat
+            | Self::L70bChat
+            | Self::L7bCode
+            | Self::L13bCode
+            | Self::L34bCode
+            | Self::Leo7b
+            | Self::Leo13b
+            | Self::Mistral7b
+            | Self::Mistral7bInstruct
+            | Self::Zephyr7bAlpha
+            | Self::Zephyr7bBeta => false,
+            Self::OpenChat35 | Self::Starling7bAlpha => true,
         }
     }
 
-    fn is_starling(&self) -> bool {
+
+    fn tokenizer_repo(&self) -> &'static str {
         match self {
             Which::L7b
             | Which::L13b
@@ -129,13 +140,15 @@ impl Which {
             | Which::L70bChat
             | Which::L7bCode
             | Which::L13bCode
-            | Which::L34bCode
-            | Which::Mistral7b
+            | Which::L34bCode => "hf-internal-testing/llama-tokenizer",
+            Which::Leo7b => "LeoLM/leo-hessianai-7b",
+            Which::Leo13b => "LeoLM/leo-hessianai-13b",
+            Which::Mistral7b
             | Which::Mistral7bInstruct
             | Which::Zephyr7bAlpha
-            | Which::Zephyr7bBeta
-            | Which::OpenChat35 => false,
-            Which::Starling7bAlpha => true,
+            | Which::Zephyr7bBeta => "mistralai/Mistral-7B-v0.1",
+            Which::OpenChat35 => "openchat/openchat_3.5",
+            Which::Starling7bAlpha => "berkeley-nest/Starling-LM-7B-alpha",
         }
     }
 }
@@ -204,15 +217,7 @@ impl Args {
             Some(config) => std::path::PathBuf::from(config),
             None => {
                 let api = hf_hub::api::sync::Api::new()?;
-                let repo = if self.which.is_starling() {
-                    "berkeley-nest/Starling-LM-7B-alpha"
-                } else if self.which.is_open_chat() {
-                    "openchat/openchat_3.5"
-                } else if self.which.is_mistral() {
-                    "mistralai/Mistral-7B-v0.1"
-                } else {
-                    "hf-internal-testing/llama-tokenizer"
-                };
+                let repo = self.which.tokenizer_repo();
                 let api = api.model(repo.to_string());
                 api.get("tokenizer.json")?
             }
@@ -243,6 +248,14 @@ impl Args {
             Which::L7bCode => ("TheBloke/CodeLlama-7B-GGUF", "codellama-7b.Q8_0.gguf"),
             Which::L13bCode => ("TheBloke/CodeLlama-13B-GGUF", "codellama-13b.Q8_0.gguf"),
             Which::L34bCode => ("TheBloke/CodeLlama-34B-GGUF", "codellama-34b.Q8_0.gguf"),
+            Which::Leo7b => (
+                "TheBloke/leo-hessianai-7B-GGUF",
+                "leo-hessianai-7b.Q4_K_M.gguf",
+            ),
+            Which::Leo13b => (
+                "TheBloke/leo-hessianai-13B-GGUF",
+                "leo-hessianai-13b.Q4_K_M.gguf",
+            ),
             Which::Mistral7b => (
                 "TheBloke/Mistral-7B-v0.1-GGUF",
                 "mistral-7b-v0.1.Q4_K_S.gguf",
@@ -358,7 +371,9 @@ fn main() -> anyhow::Result<()> {
             | Which::L13bChat
             | Which::L7bCode
             | Which::L13bCode
-            | Which::L34bCode => 1,
+            | Which::L34bCode
+            | Which::Leo7b
+            | Which::Leo13b => 1,
             Which::Mistral7b
             | Which::Mistral7bInstruct
             | Which::Zephyr7bAlpha
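
Aside from the new variants, the main refactor above replaces the chain of per-family boolean helpers (`is_starling()`, `is_open_chat()`, `is_mistral()`) that `Args::tokenizer()` used to pick a tokenizer repo with a single `tokenizer_repo()` lookup on `Which`. A minimal, self-contained sketch of that pattern follows; the trimmed-down enum and the `main` driver are illustrative stand-ins, not the actual example code.

// Sketch of the lookup pattern introduced by this commit (hypothetical
// trimmed-down enum, not the full `Which` from candle-examples).
#[derive(Clone, Copy, Debug)]
enum Which {
    L7b,
    Leo7b,
    Leo13b,
    Mistral7b,
    OpenChat35,
}

impl Which {
    // One exhaustive match instead of several `is_*()` helpers: adding a new
    // model variant makes the compiler demand its tokenizer repo here.
    fn tokenizer_repo(&self) -> &'static str {
        match self {
            Self::L7b => "hf-internal-testing/llama-tokenizer",
            Self::Leo7b => "LeoLM/leo-hessianai-7b",
            Self::Leo13b => "LeoLM/leo-hessianai-13b",
            Self::Mistral7b => "mistralai/Mistral-7B-v0.1",
            Self::OpenChat35 => "openchat/openchat_3.5",
        }
    }
}

fn main() {
    // Print the model-to-tokenizer-repo mapping for a few variants.
    for which in [Which::L7b, Which::Leo7b, Which::Mistral7b] {
        println!("{which:?} -> {}", which.tokenizer_repo());
    }
}

With the mapping centralized, `Args::tokenizer()` simply calls `self.which.tokenizer_repo()` and fetches `tokenizer.json` from that repo via `hf_hub`, as the `impl Args` hunk of the diff shows.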