Diffstat (limited to 'candle-examples/examples')
-rw-r--r--  candle-examples/examples/qwen/README.md  | 27
-rw-r--r--  candle-examples/examples/qwen/main.rs    | 38
2 files changed, 61 insertions(+), 4 deletions(-)
diff --git a/candle-examples/examples/qwen/README.md b/candle-examples/examples/qwen/README.md
new file mode 100644
index 00000000..44a50b72
--- /dev/null
+++ b/candle-examples/examples/qwen/README.md
@@ -0,0 +1,27 @@
+# candle-qwen: large language model series from Alibaba Cloud
+
+Qwen 1.5 is a series of large language models that provide strong performance
+on both English and Chinese.
+
+- [Blog post](https://qwenlm.github.io/blog/qwen1.5/) introducing Qwen1.5.
+- [Model card](https://huggingface.co/Qwen/Qwen1.5-0.5B) on the Hugging Face Hub.
+- [Blog post](https://qwenlm.github.io/blog/qwen-moe/) for the
+ mixture-of-experts (MoE) variant.
+
+## Running the example
+
+```bash
+$ cargo run --example qwen --release -- --prompt "Hello there "
+```
+
+Various model sizes are available via the `--model` argument, including the MoE
+variant.
+
+```bash
+$ cargo run --example qwen --release -- --model moe-a2.7b --prompt 'def print_prime(n: int): '
+def print_prime(n: int): # n is the number of primes to be printed
+ for i in range(2, n + 1):
+ if all(i % j != 0 for j in range(2, i)):
+ print(i)
+```
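+
+For instance, to run one of the smaller dense checkpoints instead (the
+`--model` values follow the `#[value(name = ...)]` attributes declared in
+`main.rs`; `1.8b` below selects the 1.8B checkpoint):
+
+```bash
+$ cargo run --example qwen --release -- --model 1.8b --prompt "Hello there "
+```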
+
diff --git a/candle-examples/examples/qwen/main.rs b/candle-examples/examples/qwen/main.rs
index d040d4b0..a203ad8e 100644
--- a/candle-examples/examples/qwen/main.rs
+++ b/candle-examples/examples/qwen/main.rs
@@ -7,7 +7,8 @@ extern crate accelerate_src;
use anyhow::{Error as E, Result};
use clap::Parser;
-use candle_transformers::models::qwen2::{Config, Model};
+use candle_transformers::models::qwen2::{Config as ConfigBase, Model as ModelBase};
+use candle_transformers::models::qwen2_moe::{Config as ConfigMoe, Model as ModelMoe};
use candle::{DType, Device, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
@@ -16,6 +17,20 @@ use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;
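+/// Wraps either the dense Qwen2 model or the mixture-of-experts variant so the
+/// generation loop can call `forward` without caring which one was loaded.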
+enum Model {
+ Base(ModelBase),
+ Moe(ModelMoe),
+}
+
+impl Model {
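+    // Dispatch the forward pass to whichever variant was loaded.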
+ fn forward(&mut self, xs: &Tensor, s: usize) -> candle::Result<Tensor> {
+ match self {
+ Self::Moe(ref mut m) => m.forward(xs, s),
+ Self::Base(ref mut m) => m.forward(xs, s),
+ }
+ }
+}
+
struct TextGeneration {
model: Model,
device: Device,
@@ -127,6 +142,8 @@ enum WhichModel {
W14b,
#[value(name = "72b")]
W72b,
+ #[value(name = "moe-a2.7b")]
+ MoeA27b,
}
#[derive(Parser, Debug)]
@@ -224,6 +241,7 @@ fn main() -> Result<()> {
WhichModel::W7b => "7B",
WhichModel::W14b => "14B",
WhichModel::W72b => "72B",
+ WhichModel::MoeA27b => "MoE-A2.7B",
};
format!("Qwen/Qwen1.5-{size}")
}
@@ -244,7 +262,11 @@ fn main() -> Result<()> {
.collect::<Vec<_>>(),
None => match args.model {
WhichModel::W0_5b | WhichModel::W1_8b => vec![repo.get("model.safetensors")?],
- WhichModel::W4b | WhichModel::W7b | WhichModel::W14b | WhichModel::W72b => {
+ WhichModel::W4b
+ | WhichModel::W7b
+ | WhichModel::W14b
+ | WhichModel::W72b
+ | WhichModel::MoeA27b => {
candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?
}
},
@@ -254,7 +276,6 @@ fn main() -> Result<()> {
let start = std::time::Instant::now();
let config_file = repo.get("config.json")?;
- let config: Config = serde_json::from_slice(&std::fs::read(config_file)?)?;
let device = candle_examples::device(args.cpu)?;
let dtype = if device.is_cuda() {
DType::BF16
@@ -262,7 +283,16 @@ fn main() -> Result<()> {
DType::F32
};
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
- let model = Model::new(&config, vb)?;
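+    // The MoE checkpoint uses a different config schema, so deserialize the
+    // matching Config type and construct the corresponding model variant.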
+ let model = match args.model {
+ WhichModel::MoeA27b => {
+ let config: ConfigMoe = serde_json::from_slice(&std::fs::read(config_file)?)?;
+ Model::Moe(ModelMoe::new(&config, vb)?)
+ }
+ _ => {
+ let config: ConfigBase = serde_json::from_slice(&std::fs::read(config_file)?)?;
+ Model::Base(ModelBase::new(&config, vb)?)
+ }
+ };
println!("loaded the model in {:?}", start.elapsed());