summaryrefslogtreecommitdiff
path: root/candle-examples
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2024-06-07 10:51:50 +0100
committerGitHub <noreply@github.com>2024-06-07 10:51:50 +0100
commit54ff971e35a0fd28da062d416ffb7bc9ac9d40d8 (patch)
treec5ee47770b4f1195bc66e0bdbe75a630a4ccadbc /candle-examples
parentb9fac7ec008bfccf8900552f51e6d0e865280ee9 (diff)
downloadcandle-54ff971e35a0fd28da062d416ffb7bc9ac9d40d8.tar.gz
candle-54ff971e35a0fd28da062d416ffb7bc9ac9d40d8.tar.bz2
candle-54ff971e35a0fd28da062d416ffb7bc9ac9d40d8.zip
Support for the new Qwen2 models. (#2257)
* Support for the new Qwen2 models.
* Add more models.
Diffstat (limited to 'candle-examples')
-rw-r--r--candle-examples/examples/qwen/main.rs36
1 file changed, 26 insertions(+), 10 deletions(-)
diff --git a/candle-examples/examples/qwen/main.rs b/candle-examples/examples/qwen/main.rs
index 008bada2..53f2f70d 100644
--- a/candle-examples/examples/qwen/main.rs
+++ b/candle-examples/examples/qwen/main.rs
@@ -144,6 +144,14 @@ enum WhichModel {
W72b,
#[value(name = "moe-a2.7b")]
MoeA27b,
+ #[value(name = "2-0.5b")]
+ W2_0_5b,
+ #[value(name = "2-1.5b")]
+ W2_1_5b,
+ #[value(name = "2-7b")]
+ W2_7b,
+ #[value(name = "2-72b")]
+ W2_72b,
}
#[derive(Parser, Debug)]
@@ -234,16 +242,20 @@ fn main() -> Result<()> {
let model_id = match args.model_id {
Some(model_id) => model_id,
None => {
- let size = match args.model {
- WhichModel::W0_5b => "0.5B",
- WhichModel::W1_8b => "1.8B",
- WhichModel::W4b => "4B",
- WhichModel::W7b => "7B",
- WhichModel::W14b => "14B",
- WhichModel::W72b => "72B",
- WhichModel::MoeA27b => "MoE-A2.7B",
+ let (version, size) = match args.model {
+ WhichModel::W2_0_5b => ("2", "0.5B"),
+ WhichModel::W2_1_5b => ("2", "1.5B"),
+ WhichModel::W2_7b => ("2", "7B"),
+ WhichModel::W2_72b => ("2", "72B"),
+ WhichModel::W0_5b => ("1.5", "0.5B"),
+ WhichModel::W1_8b => ("1.5", "1.8B"),
+ WhichModel::W4b => ("1.5", "4B"),
+ WhichModel::W7b => ("1.5", "7B"),
+ WhichModel::W14b => ("1.5", "14B"),
+ WhichModel::W72b => ("1.5", "72B"),
+ WhichModel::MoeA27b => ("1.5", "MoE-A2.7B"),
};
- format!("Qwen/Qwen1.5-{size}")
+ format!("Qwen/Qwen{version}-{size}")
}
};
let repo = api.repo(Repo::with_revision(
@@ -261,11 +273,15 @@ fn main() -> Result<()> {
.map(std::path::PathBuf::from)
.collect::<Vec<_>>(),
None => match args.model {
- WhichModel::W0_5b | WhichModel::W1_8b => vec![repo.get("model.safetensors")?],
+ WhichModel::W0_5b | WhichModel::W2_0_5b | WhichModel::W2_1_5b | WhichModel::W1_8b => {
+ vec![repo.get("model.safetensors")?]
+ }
WhichModel::W4b
| WhichModel::W7b
+ | WhichModel::W2_7b
| WhichModel::W14b
| WhichModel::W72b
+ | WhichModel::W2_72b
| WhichModel::MoeA27b => {
candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?
}