diff options
Diffstat (limited to 'candle-examples/examples/qwen/main.rs')
-rw-r--r-- | candle-examples/examples/qwen/main.rs | 38 |
1 files changed, 34 insertions, 4 deletions
diff --git a/candle-examples/examples/qwen/main.rs b/candle-examples/examples/qwen/main.rs index d040d4b0..a203ad8e 100644 --- a/candle-examples/examples/qwen/main.rs +++ b/candle-examples/examples/qwen/main.rs @@ -7,7 +7,8 @@ extern crate accelerate_src; use anyhow::{Error as E, Result}; use clap::Parser; -use candle_transformers::models::qwen2::{Config, Model}; +use candle_transformers::models::qwen2::{Config as ConfigBase, Model as ModelBase}; +use candle_transformers::models::qwen2_moe::{Config as ConfigMoe, Model as ModelMoe}; use candle::{DType, Device, Tensor}; use candle_examples::token_output_stream::TokenOutputStream; @@ -16,6 +17,20 @@ use candle_transformers::generation::LogitsProcessor; use hf_hub::{api::sync::Api, Repo, RepoType}; use tokenizers::Tokenizer; +enum Model { + Base(ModelBase), + Moe(ModelMoe), +} + +impl Model { + fn forward(&mut self, xs: &Tensor, s: usize) -> candle::Result<Tensor> { + match self { + Self::Moe(ref mut m) => m.forward(xs, s), + Self::Base(ref mut m) => m.forward(xs, s), + } + } +} + struct TextGeneration { model: Model, device: Device, @@ -127,6 +142,8 @@ enum WhichModel { W14b, #[value(name = "72b")] W72b, + #[value(name = "moe-a2.7b")] + MoeA27b, } #[derive(Parser, Debug)] @@ -224,6 +241,7 @@ fn main() -> Result<()> { WhichModel::W7b => "7B", WhichModel::W14b => "14B", WhichModel::W72b => "72B", + WhichModel::MoeA27b => "MoE-A2.7B", }; format!("Qwen/Qwen1.5-{size}") } @@ -244,7 +262,11 @@ fn main() -> Result<()> { .collect::<Vec<_>>(), None => match args.model { WhichModel::W0_5b | WhichModel::W1_8b => vec![repo.get("model.safetensors")?], - WhichModel::W4b | WhichModel::W7b | WhichModel::W14b | WhichModel::W72b => { + WhichModel::W4b + | WhichModel::W7b + | WhichModel::W14b + | WhichModel::W72b + | WhichModel::MoeA27b => { candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")? } }, @@ -254,7 +276,6 @@ fn main() -> Result<()> { let start = std::time::Instant::now(); let config_file = repo.get("config.json")?; - let config: Config = serde_json::from_slice(&std::fs::read(config_file)?)?; let device = candle_examples::device(args.cpu)?; let dtype = if device.is_cuda() { DType::BF16 @@ -262,7 +283,16 @@ fn main() -> Result<()> { DType::F32 }; let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? }; - let model = Model::new(&config, vb)?; + let model = match args.model { + WhichModel::MoeA27b => { + let config: ConfigMoe = serde_json::from_slice(&std::fs::read(config_file)?)?; + Model::Moe(ModelMoe::new(&config, vb)?) + } + _ => { + let config: ConfigBase = serde_json::from_slice(&std::fs::read(config_file)?)?; + Model::Base(ModelBase::new(&config, vb)?) + } + }; println!("loaded the model in {:?}", start.elapsed()); |