summaryrefslogtreecommitdiff
path: root/candle-examples/examples/qwen/main.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-examples/examples/qwen/main.rs')
-rw-r--r--candle-examples/examples/qwen/main.rs38
1 files changed, 34 insertions, 4 deletions
diff --git a/candle-examples/examples/qwen/main.rs b/candle-examples/examples/qwen/main.rs
index d040d4b0..a203ad8e 100644
--- a/candle-examples/examples/qwen/main.rs
+++ b/candle-examples/examples/qwen/main.rs
@@ -7,7 +7,8 @@ extern crate accelerate_src;
use anyhow::{Error as E, Result};
use clap::Parser;
-use candle_transformers::models::qwen2::{Config, Model};
+use candle_transformers::models::qwen2::{Config as ConfigBase, Model as ModelBase};
+use candle_transformers::models::qwen2_moe::{Config as ConfigMoe, Model as ModelMoe};
use candle::{DType, Device, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
@@ -16,6 +17,20 @@ use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;
+enum Model {
+ Base(ModelBase),
+ Moe(ModelMoe),
+}
+
+impl Model {
+ fn forward(&mut self, xs: &Tensor, s: usize) -> candle::Result<Tensor> {
+ match self {
+ Self::Moe(ref mut m) => m.forward(xs, s),
+ Self::Base(ref mut m) => m.forward(xs, s),
+ }
+ }
+}
+
struct TextGeneration {
model: Model,
device: Device,
@@ -127,6 +142,8 @@ enum WhichModel {
W14b,
#[value(name = "72b")]
W72b,
+ #[value(name = "moe-a2.7b")]
+ MoeA27b,
}
#[derive(Parser, Debug)]
@@ -224,6 +241,7 @@ fn main() -> Result<()> {
WhichModel::W7b => "7B",
WhichModel::W14b => "14B",
WhichModel::W72b => "72B",
+ WhichModel::MoeA27b => "MoE-A2.7B",
};
format!("Qwen/Qwen1.5-{size}")
}
@@ -244,7 +262,11 @@ fn main() -> Result<()> {
.collect::<Vec<_>>(),
None => match args.model {
WhichModel::W0_5b | WhichModel::W1_8b => vec![repo.get("model.safetensors")?],
- WhichModel::W4b | WhichModel::W7b | WhichModel::W14b | WhichModel::W72b => {
+ WhichModel::W4b
+ | WhichModel::W7b
+ | WhichModel::W14b
+ | WhichModel::W72b
+ | WhichModel::MoeA27b => {
candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?
}
},
@@ -254,7 +276,6 @@ fn main() -> Result<()> {
let start = std::time::Instant::now();
let config_file = repo.get("config.json")?;
- let config: Config = serde_json::from_slice(&std::fs::read(config_file)?)?;
let device = candle_examples::device(args.cpu)?;
let dtype = if device.is_cuda() {
DType::BF16
@@ -262,7 +283,16 @@ fn main() -> Result<()> {
DType::F32
};
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
- let model = Model::new(&config, vb)?;
+ let model = match args.model {
+ WhichModel::MoeA27b => {
+ let config: ConfigMoe = serde_json::from_slice(&std::fs::read(config_file)?)?;
+ Model::Moe(ModelMoe::new(&config, vb)?)
+ }
+ _ => {
+ let config: ConfigBase = serde_json::from_slice(&std::fs::read(config_file)?)?;
+ Model::Base(ModelBase::new(&config, vb)?)
+ }
+ };
println!("loaded the model in {:?}", start.elapsed());