Diffstat (limited to 'candle-examples/examples/quantized/main.rs')
-rw-r--r--  candle-examples/examples/quantized/main.rs  103
1 file changed, 85 insertions(+), 18 deletions(-)
diff --git a/candle-examples/examples/quantized/main.rs b/candle-examples/examples/quantized/main.rs
index ab8a56ba..df758b4f 100644
--- a/candle-examples/examples/quantized/main.rs
+++ b/candle-examples/examples/quantized/main.rs
@@ -45,6 +45,10 @@ enum Which {
     L13bCode,
     #[value(name = "32b-code")]
     L34bCode,
+    #[value(name = "7b-leo")]
+    Leo7b,
+    #[value(name = "13b-leo")]
+    Leo13b,
     #[value(name = "7b-mistral")]
     Mistral7b,
     #[value(name = "7b-mistral-instruct")]
@@ -55,6 +59,12 @@ enum Which {
     Zephyr7bBeta,
     #[value(name = "7b-open-chat-3.5")]
     OpenChat35,
+    #[value(name = "7b-starling-a")]
+    Starling7bAlpha,
+    #[value(name = "mixtral")]
+    Mixtral,
+    #[value(name = "mixtral-instruct")]
+    MixtralInstruct,
 }
 
 impl Which {
@@ -68,12 +78,17 @@ impl Which {
             | Self::L70bChat
             | Self::L7bCode
             | Self::L13bCode
-            | Self::L34bCode => false,
+            | Self::L34bCode
+            | Self::Leo7b
+            | Self::Leo13b => false,
             // Zephyr and OpenChat are fine tuned versions of mistral and should be treated in the
-            // same way.
+            // same way. Starling is a fine tuned version of OpenChat.
             Self::OpenChat35
+            | Self::Starling7bAlpha
             | Self::Zephyr7bAlpha
             | Self::Zephyr7bBeta
+            | Self::Mixtral
+            | Self::MixtralInstruct
             | Self::Mistral7b
             | Self::Mistral7bInstruct => true,
         }
@@ -90,15 +105,43 @@ impl Which {
             | Self::L7bCode
             | Self::L13bCode
             | Self::L34bCode
+            | Self::Leo7b
+            | Self::Leo13b
+            | Self::Mixtral
+            | Self::MixtralInstruct
             | Self::Mistral7b
             | Self::Mistral7bInstruct
-            | Self::OpenChat35 => false,
+            | Self::OpenChat35
+            | Self::Starling7bAlpha => false,
             Self::Zephyr7bAlpha | Self::Zephyr7bBeta => true,
         }
     }
 
     fn is_open_chat(&self) -> bool {
         match self {
+            Self::L7b
+            | Self::L13b
+            | Self::L70b
+            | Self::L7bChat
+            | Self::L13bChat
+            | Self::L70bChat
+            | Self::L7bCode
+            | Self::L13bCode
+            | Self::L34bCode
+            | Self::Leo7b
+            | Self::Leo13b
+            | Self::Mixtral
+            | Self::MixtralInstruct
+            | Self::Mistral7b
+            | Self::Mistral7bInstruct
+            | Self::Zephyr7bAlpha
+            | Self::Zephyr7bBeta => false,
+            Self::OpenChat35 | Self::Starling7bAlpha => true,
+        }
+    }
+
+    fn tokenizer_repo(&self) -> &'static str {
+        match self {
             Which::L7b
             | Which::L13b
             | Which::L70b
@@ -107,12 +150,17 @@ impl Which {
             | Which::L70bChat
             | Which::L7bCode
             | Which::L13bCode
-            | Which::L34bCode
-            | Which::Mistral7b
+            | Which::L34bCode => "hf-internal-testing/llama-tokenizer",
+            Which::Leo7b => "LeoLM/leo-hessianai-7b",
+            Which::Leo13b => "LeoLM/leo-hessianai-13b",
+            Which::Mixtral => "mistralai/Mixtral-8x7B-v0.1",
+            Which::MixtralInstruct => "mistralai/Mixtral-8x7B-Instruct-v0.1",
+            Which::Mistral7b
            | Which::Mistral7bInstruct
            | Which::Zephyr7bAlpha
-            | Which::Zephyr7bBeta => false,
-            Which::OpenChat35 => true,
+            | Which::Zephyr7bBeta => "mistralai/Mistral-7B-v0.1",
+            Which::OpenChat35 => "openchat/openchat_3.5",
+            Which::Starling7bAlpha => "berkeley-nest/Starling-LM-7B-alpha",
         }
     }
 }
@@ -181,13 +229,7 @@ impl Args {
             Some(config) => std::path::PathBuf::from(config),
             None => {
                 let api = hf_hub::api::sync::Api::new()?;
-                let repo = if self.which.is_open_chat() {
-                    "openchat/openchat_3.5"
-                } else if self.which.is_mistral() {
-                    "mistralai/Mistral-7B-v0.1"
-                } else {
-                    "hf-internal-testing/llama-tokenizer"
-                };
+                let repo = self.which.tokenizer_repo();
                 let api = api.model(repo.to_string());
                 api.get("tokenizer.json")?
             }
@@ -218,6 +260,22 @@ impl Args {
                 Which::L7bCode => ("TheBloke/CodeLlama-7B-GGUF", "codellama-7b.Q8_0.gguf"),
                 Which::L13bCode => ("TheBloke/CodeLlama-13B-GGUF", "codellama-13b.Q8_0.gguf"),
                 Which::L34bCode => ("TheBloke/CodeLlama-34B-GGUF", "codellama-34b.Q8_0.gguf"),
+                Which::Leo7b => (
+                    "TheBloke/leo-hessianai-7B-GGUF",
+                    "leo-hessianai-7b.Q4_K_M.gguf",
+                ),
+                Which::Leo13b => (
+                    "TheBloke/leo-hessianai-13B-GGUF",
+                    "leo-hessianai-13b.Q4_K_M.gguf",
+                ),
+                Which::Mixtral => (
+                    "TheBloke/Mixtral-8x7B-v0.1-GGUF",
+                    "mixtral-8x7b-v0.1.Q4_K_M.gguf",
+                ),
+                Which::MixtralInstruct => (
+                    "TheBloke/Mixtral-8x7B-Instruct-v0.1-GGUF",
+                    "mixtral-8x7b-instruct-v0.1.Q4_K_M.gguf",
+                ),
                 Which::Mistral7b => (
                     "TheBloke/Mistral-7B-v0.1-GGUF",
                     "mistral-7b-v0.1.Q4_K_S.gguf",
@@ -234,6 +292,10 @@ impl Args {
                     ("TheBloke/zephyr-7B-beta-GGUF", "zephyr-7b-beta.Q4_K_M.gguf")
                 }
                 Which::OpenChat35 => ("TheBloke/openchat_3.5-GGUF", "openchat_3.5.Q4_K_M.gguf"),
+                Which::Starling7bAlpha => (
+                    "TheBloke/Starling-LM-7B-alpha-GGUF",
+                    "starling-lm-7b-alpha.Q4_K_M.gguf",
+                ),
             };
             let api = hf_hub::api::sync::Api::new()?;
             let api = api.model(repo.to_string());
@@ -329,14 +391,19 @@ fn main() -> anyhow::Result<()> {
                 | Which::L13bChat
                 | Which::L7bCode
                 | Which::L13bCode
-                | Which::L34bCode => 1,
-                Which::Mistral7b
+                | Which::L34bCode
+                | Which::Leo7b
+                | Which::Leo13b => 1,
+                Which::Mixtral
+                | Which::MixtralInstruct
+                | Which::Mistral7b
                 | Which::Mistral7bInstruct
                 | Which::Zephyr7bAlpha
                 | Which::Zephyr7bBeta
                 | Which::L70b
                 | Which::L70bChat
-                | Which::OpenChat35 => 8,
+                | Which::OpenChat35
+                | Which::Starling7bAlpha => 8,
             };
             ModelWeights::from_ggml(model, args.gqa.unwrap_or(default_gqa))?
         }
@@ -369,7 +436,7 @@ fn main() -> anyhow::Result<()> {
             }
         }
         if args.which.is_open_chat() {
-            format!("User: {prompt}<|end_of_turn|>Assistant: ")
+            format!("GPT4 Correct User: {prompt}<|end_of_turn|>GPT4 Correct Assistant:")
        } else if args.which.is_zephyr() {
            if prompt_index == 0 || is_interactive {
                format!("<|system|>\n</s>\n<|user|>\n{prompt}</s>\n<|assistant|>",)
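For context, the tokenizer lookup that the new `tokenizer_repo()` method feeds into boils down to a couple of hf-hub calls, visible in the hunk around line 229. Below is a minimal standalone sketch of that flow; it assumes the hf-hub, tokenizers, and anyhow crates, and the `fetch_tokenizer` helper is hypothetical rather than part of the example:

use tokenizers::Tokenizer;

// Download (or reuse a cached copy of) a repo's tokenizer.json and load it.
// `repo` is the per-model id, e.g. a value returned by Which::tokenizer_repo.
fn fetch_tokenizer(repo: &str) -> anyhow::Result<Tokenizer> {
    let api = hf_hub::api::sync::Api::new()?;
    let api = api.model(repo.to_string());
    let tokenizer_path = api.get("tokenizer.json")?;
    Tokenizer::from_file(tokenizer_path).map_err(anyhow::Error::msg)
}

fn main() -> anyhow::Result<()> {
    // e.g. the repo the new Leo7b variant resolves to.
    let tokenizer = fetch_tokenizer("LeoLM/leo-hessianai-7b")?;
    let encoding = tokenizer
        .encode("Hallo Welt", true)
        .map_err(anyhow::Error::msg)?;
    println!("{:?}", encoding.get_ids());
    Ok(())
}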
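The weight files follow the same download path: each `Which` variant maps to a (repo, filename) pair and hf-hub fetches the quantized GGUF into the local hub cache. A small sketch under the same assumptions, using the Starling pair added above:

fn main() -> anyhow::Result<()> {
    let api = hf_hub::api::sync::Api::new()?;
    let repo = api.model("TheBloke/Starling-LM-7B-alpha-GGUF".to_string());
    // Resolves to a path inside the Hugging Face hub cache,
    // downloading the file on first use.
    let model_path = repo.get("starling-lm-7b-alpha.Q4_K_M.gguf")?;
    println!("weights at {}", model_path.display());
    Ok(())
}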
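Finally, the last hunk swaps the plain `User:`/`Assistant:` roles for the "GPT4 Correct" template that OpenChat 3.5 and Starling were trained on, with each turn terminated by the `<|end_of_turn|>` special token. The example itself formats one turn at a time; as a sketch of what the full template looks like across several turns (the `build_openchat_prompt` helper is hypothetical, and the exact multi-turn spacing is an assumption based on the published OpenChat format):

// Assemble an OpenChat/Starling-style conversation prompt. Past turns are
// (user, assistant) pairs; the final assistant role is left open so the
// model produces the next reply.
fn build_openchat_prompt(turns: &[(&str, &str)], next_user: &str) -> String {
    let mut prompt = String::new();
    for (user, assistant) in turns {
        prompt.push_str(&format!(
            "GPT4 Correct User: {user}<|end_of_turn|>GPT4 Correct Assistant: {assistant}<|end_of_turn|>"
        ));
    }
    prompt.push_str(&format!(
        "GPT4 Correct User: {next_user}<|end_of_turn|>GPT4 Correct Assistant:"
    ));
    prompt
}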