diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-06-07 10:51:50 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-06-07 10:51:50 +0100 |
commit | 54ff971e35a0fd28da062d416ffb7bc9ac9d40d8 (patch) | |
tree | c5ee47770b4f1195bc66e0bdbe75a630a4ccadbc /candle-examples | |
parent | b9fac7ec008bfccf8900552f51e6d0e865280ee9 (diff) | |
download | candle-54ff971e35a0fd28da062d416ffb7bc9ac9d40d8.tar.gz candle-54ff971e35a0fd28da062d416ffb7bc9ac9d40d8.tar.bz2 candle-54ff971e35a0fd28da062d416ffb7bc9ac9d40d8.zip |
Support for the new Qwen2 models. (#2257)
* Support for the new Qwen2 models.
* Add more models.
Diffstat (limited to 'candle-examples')
-rw-r--r-- | candle-examples/examples/qwen/main.rs | 36 |
1 files changed, 26 insertions, 10 deletions
diff --git a/candle-examples/examples/qwen/main.rs b/candle-examples/examples/qwen/main.rs index 008bada2..53f2f70d 100644 --- a/candle-examples/examples/qwen/main.rs +++ b/candle-examples/examples/qwen/main.rs @@ -144,6 +144,14 @@ enum WhichModel { W72b, #[value(name = "moe-a2.7b")] MoeA27b, + #[value(name = "2-0.5b")] + W2_0_5b, + #[value(name = "2-1.5b")] + W2_1_5b, + #[value(name = "2-7b")] + W2_7b, + #[value(name = "2-72b")] + W2_72b, } #[derive(Parser, Debug)] @@ -234,16 +242,20 @@ fn main() -> Result<()> { let model_id = match args.model_id { Some(model_id) => model_id, None => { - let size = match args.model { - WhichModel::W0_5b => "0.5B", - WhichModel::W1_8b => "1.8B", - WhichModel::W4b => "4B", - WhichModel::W7b => "7B", - WhichModel::W14b => "14B", - WhichModel::W72b => "72B", - WhichModel::MoeA27b => "MoE-A2.7B", + let (version, size) = match args.model { + WhichModel::W2_0_5b => ("2", "0.5B"), + WhichModel::W2_1_5b => ("2", "1.5B"), + WhichModel::W2_7b => ("2", "7B"), + WhichModel::W2_72b => ("2", "72B"), + WhichModel::W0_5b => ("1.5", "0.5B"), + WhichModel::W1_8b => ("1.5", "1.8B"), + WhichModel::W4b => ("1.5", "4B"), + WhichModel::W7b => ("1.5", "7B"), + WhichModel::W14b => ("1.5", "14B"), + WhichModel::W72b => ("1.5", "72B"), + WhichModel::MoeA27b => ("1.5", "MoE-A2.7B"), }; - format!("Qwen/Qwen1.5-{size}") + format!("Qwen/Qwen{version}-{size}") } }; let repo = api.repo(Repo::with_revision( @@ -261,11 +273,15 @@ fn main() -> Result<()> { .map(std::path::PathBuf::from) .collect::<Vec<_>>(), None => match args.model { - WhichModel::W0_5b | WhichModel::W1_8b => vec![repo.get("model.safetensors")?], + WhichModel::W0_5b | WhichModel::W2_0_5b | WhichModel::W2_1_5b | WhichModel::W1_8b => { + vec![repo.get("model.safetensors")?] + } WhichModel::W4b | WhichModel::W7b + | WhichModel::W2_7b | WhichModel::W14b | WhichModel::W72b + | WhichModel::W2_72b | WhichModel::MoeA27b => { candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")? } |