author     Zhuo Jinggang <jg.zhuo@outlook.com>  2024-07-12 16:00:03 +0800
committer  GitHub <noreply@github.com>          2024-07-12 10:00:03 +0200
commit     c63048d3748649c6f13148eb01e6d812d897a0d2 (patch)
tree       275f50476521bf47bb89530dd822a45ae776e6d3 /candle-examples
parent     a226a9736baee550b01de53cb3e416d3d94e69d3 (diff)
add quantized qwen2 (#2329)
* add quantized version of qwen2 and corresponding example for qwen2-instruct
* fix quantized qwen2 clippy error
Diffstat (limited to 'candle-examples')
-rw-r--r--  candle-examples/examples/quantized-qwen2-instruct/README.md   11
-rw-r--r--  candle-examples/examples/quantized-qwen2-instruct/main.rs    306
2 files changed, 317 insertions, 0 deletions
diff --git a/candle-examples/examples/quantized-qwen2-instruct/README.md b/candle-examples/examples/quantized-qwen2-instruct/README.md
new file mode 100644
index 00000000..8129b3fc
--- /dev/null
+++ b/candle-examples/examples/quantized-qwen2-instruct/README.md
@@ -0,0 +1,11 @@
+# candle-quantized-qwen2-instruct
+
+[Qwen2](https://qwenlm.github.io/blog/qwen2/) is an upgraded version of Qwen1.5, released by Alibaba Cloud.
+
+## Running the example
+
+```bash
+cargo run --example quantized-qwen2-instruct --release -- --prompt "Write a function to count prime numbers up to N."
+```
+
+The 0.5b, 1.5b, 7b and 72b models are available via the `--which` argument.
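+
+For example, to run the 1.5b model (the `--which` values are those defined in `main.rs`):
+
+```bash
+cargo run --example quantized-qwen2-instruct --release -- --which 1.5b --prompt "Write a function to count prime numbers up to N."
+```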
diff --git a/candle-examples/examples/quantized-qwen2-instruct/main.rs b/candle-examples/examples/quantized-qwen2-instruct/main.rs
new file mode 100644
index 00000000..1bd230e0
--- /dev/null
+++ b/candle-examples/examples/quantized-qwen2-instruct/main.rs
@@ -0,0 +1,306 @@
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
+
+#[cfg(feature = "accelerate")]
+extern crate accelerate_src;
+
+use clap::{Parser, ValueEnum};
+use std::io::Write;
+use tokenizers::Tokenizer;
+
+use candle::quantized::gguf_file;
+use candle::Tensor;
+use candle_transformers::generation::{LogitsProcessor, Sampling};
+
+use candle_examples::token_output_stream::TokenOutputStream;
+use candle_transformers::models::quantized_qwen2::ModelWeights as Qwen2;
+
+const DEFAULT_PROMPT: &str = "Write a function to count prime numbers up to N. ";
+
+#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
+enum Which {
+ #[value(name = "0.5b")]
+ W2_0_5b,
+ #[value(name = "1.5b")]
+ W2_1_5b,
+ #[value(name = "7b")]
+ W2_7b,
+ #[value(name = "72b")]
+ W2_72b,
+}
+
+#[derive(Parser, Debug)]
+#[command(author, version, about, long_about = None)]
+struct Args {
+ /// GGUF file to load, typically a .gguf file generated by the quantize command from llama.cpp
+ #[arg(long)]
+ model: Option<String>,
+
+    /// The initial prompt to generate from.
+ #[arg(long)]
+ prompt: Option<String>,
+
+ /// The length of the sample to generate (in tokens).
+ #[arg(short = 'n', long, default_value_t = 1000)]
+ sample_len: usize,
+
+ /// The tokenizer config in json format.
+ #[arg(long)]
+ tokenizer: Option<String>,
+
+ /// The temperature used to generate samples, use 0 for greedy sampling.
+ #[arg(long, default_value_t = 0.8)]
+ temperature: f64,
+
+ /// Nucleus sampling probability cutoff.
+ #[arg(long)]
+ top_p: Option<f64>,
+
+ /// Only sample among the top K samples.
+ #[arg(long)]
+ top_k: Option<usize>,
+
+ /// The seed to use when generating random samples.
+ #[arg(long, default_value_t = 299792458)]
+ seed: u64,
+
+ /// Enable tracing (generates a trace-timestamp.json file).
+ #[arg(long)]
+ tracing: bool,
+
+ /// Process prompt elements separately.
+ #[arg(long)]
+ split_prompt: bool,
+
+ /// Run on CPU rather than GPU even if a GPU is available.
+ #[arg(long)]
+ cpu: bool,
+
+ /// Penalty to be applied for repeating tokens, 1. means no penalty.
+ #[arg(long, default_value_t = 1.1)]
+ repeat_penalty: f32,
+
+ /// The context size to consider for the repeat penalty.
+ #[arg(long, default_value_t = 64)]
+ repeat_last_n: usize,
+
+ /// The model size to use.
+ #[arg(long, default_value = "0.5b")]
+ which: Which,
+}
+
+impl Args {
+ fn tokenizer(&self) -> anyhow::Result<Tokenizer> {
+ let tokenizer_path = match &self.tokenizer {
+ Some(config) => std::path::PathBuf::from(config),
+ None => {
+ let api = hf_hub::api::sync::Api::new()?;
+ let repo = match self.which {
+ Which::W2_0_5b => "Qwen/Qwen2-0.5B-Instruct",
+ Which::W2_1_5b => "Qwen/Qwen2-1.5B-Instruct",
+ Which::W2_7b => "Qwen/Qwen2-7B-Instruct",
+ Which::W2_72b => "Qwen/Qwen2-72B-Instruct",
+ };
+ let api = api.model(repo.to_string());
+ api.get("tokenizer.json")?
+ }
+ };
+ Tokenizer::from_file(tokenizer_path).map_err(anyhow::Error::msg)
+ }
+
+ fn model(&self) -> anyhow::Result<std::path::PathBuf> {
+ let model_path = match &self.model {
+ Some(config) => std::path::PathBuf::from(config),
+ None => {
+ let (repo, filename, revision) = match self.which {
+ Which::W2_0_5b => (
+ "Qwen/Qwen2-0.5B-Instruct-GGUF",
+ "qwen2-0_5b-instruct-q4_0.gguf",
+ "main",
+ ),
+ Which::W2_1_5b => (
+ "Qwen/Qwen2-1.5B-Instruct-GGUF",
+ "qwen2-1_5b-instruct-q4_0.gguf",
+ "main",
+ ),
+ Which::W2_7b => (
+ "Qwen/Qwen2-7B-Instruct-GGUF",
+ "qwen2-7b-instruct-q4_0.gguf",
+ "main",
+ ),
+ Which::W2_72b => (
+ "Qwen/Qwen2-72B-Instruct-GGUF",
+ "qwen2-72b-instruct-q4_0.gguf",
+ "main",
+ ),
+ };
+ let api = hf_hub::api::sync::Api::new()?;
+ api.repo(hf_hub::Repo::with_revision(
+ repo.to_string(),
+ hf_hub::RepoType::Model,
+ revision.to_string(),
+ ))
+ .get(filename)?
+ }
+ };
+ Ok(model_path)
+ }
+}
+
+fn format_size(size_in_bytes: usize) -> String {
+ if size_in_bytes < 1_000 {
+ format!("{}B", size_in_bytes)
+ } else if size_in_bytes < 1_000_000 {
+ format!("{:.2}KB", size_in_bytes as f64 / 1e3)
+ } else if size_in_bytes < 1_000_000_000 {
+ format!("{:.2}MB", size_in_bytes as f64 / 1e6)
+ } else {
+ format!("{:.2}GB", size_in_bytes as f64 / 1e9)
+ }
+}
+
+fn main() -> anyhow::Result<()> {
+ use tracing_chrome::ChromeLayerBuilder;
+ use tracing_subscriber::prelude::*;
+
+ let args = Args::parse();
+ let _guard = if args.tracing {
+ let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
+ tracing_subscriber::registry().with(chrome_layer).init();
+ Some(guard)
+ } else {
+ None
+ };
+
+ println!(
+ "avx: {}, neon: {}, simd128: {}, f16c: {}",
+ candle::utils::with_avx(),
+ candle::utils::with_neon(),
+ candle::utils::with_simd128(),
+ candle::utils::with_f16c()
+ );
+ println!(
+ "temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
+ args.temperature, args.repeat_penalty, args.repeat_last_n
+ );
+
+ let model_path = args.model()?;
+ let mut file = std::fs::File::open(&model_path)?;
+ let start = std::time::Instant::now();
+ let device = candle_examples::device(args.cpu)?;
+
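+    // Read the GGUF contents, report the total tensor size, then build the model weights.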
+ let mut model = {
+ let model = gguf_file::Content::read(&mut file).map_err(|e| e.with_path(model_path))?;
+ let mut total_size_in_bytes = 0;
+ for (_, tensor) in model.tensor_infos.iter() {
+ let elem_count = tensor.shape.elem_count();
+ total_size_in_bytes +=
+ elem_count * tensor.ggml_dtype.type_size() / tensor.ggml_dtype.block_size();
+ }
+ println!(
+ "loaded {:?} tensors ({}) in {:.2}s",
+ model.tensor_infos.len(),
+ &format_size(total_size_in_bytes),
+ start.elapsed().as_secs_f32(),
+ );
+ Qwen2::from_gguf(model, &mut file, &device)?
+ };
+ println!("model built");
+
+ let tokenizer = args.tokenizer()?;
+ let mut tos = TokenOutputStream::new(tokenizer);
+ let prompt_str = args.prompt.unwrap_or_else(|| DEFAULT_PROMPT.to_string());
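+    // Wrap the raw prompt in the Qwen2 instruct chat template (<|im_start|>/<|im_end|> markers).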
+ let prompt_str = format!(
+ "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n",
+ prompt_str
+ );
+ print!("formatted instruct prompt: {}", &prompt_str);
+ let tokens = tos
+ .tokenizer()
+ .encode(prompt_str, true)
+ .map_err(anyhow::Error::msg)?;
+ let tokens = tokens.get_ids();
+ let to_sample = args.sample_len.saturating_sub(1);
+ let mut all_tokens = vec![];
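+    // Choose the sampling strategy from the temperature / top-k / top-p flags.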
+ let mut logits_processor = {
+ let temperature = args.temperature;
+ let sampling = if temperature <= 0. {
+ Sampling::ArgMax
+ } else {
+ match (args.top_k, args.top_p) {
+ (None, None) => Sampling::All { temperature },
+ (Some(k), None) => Sampling::TopK { k, temperature },
+ (None, Some(p)) => Sampling::TopP { p, temperature },
+ (Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
+ }
+ };
+ LogitsProcessor::from_sampling(args.seed, sampling)
+ };
+ let start_prompt_processing = std::time::Instant::now();
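+    // Feed the whole prompt in a single forward pass unless --split-prompt is set,
+    // in which case prompt tokens are processed one at a time.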
+ let mut next_token = if !args.split_prompt {
+ let input = Tensor::new(tokens, &device)?.unsqueeze(0)?;
+ let logits = model.forward(&input, 0)?;
+ let logits = logits.squeeze(0)?;
+ logits_processor.sample(&logits)?
+ } else {
+ let mut next_token = 0;
+ for (pos, token) in tokens.iter().enumerate() {
+ let input = Tensor::new(&[*token], &device)?.unsqueeze(0)?;
+ let logits = model.forward(&input, pos)?;
+ let logits = logits.squeeze(0)?;
+ next_token = logits_processor.sample(&logits)?
+ }
+ next_token
+ };
+ let prompt_dt = start_prompt_processing.elapsed();
+ all_tokens.push(next_token);
+ if let Some(t) = tos.next_token(next_token)? {
+ print!("{t}");
+ std::io::stdout().flush()?;
+ }
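+    // Generation stops once the <|im_end|> token is sampled.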
+ let eos_token = *tos.tokenizer().get_vocab(true).get("<|im_end|>").unwrap();
+ let start_post_prompt = std::time::Instant::now();
+ let mut sampled = 0;
+ for index in 0..to_sample {
+ let input = Tensor::new(&[next_token], &device)?.unsqueeze(0)?;
+ let logits = model.forward(&input, tokens.len() + index)?;
+ let logits = logits.squeeze(0)?;
+ let logits = if args.repeat_penalty == 1. {
+ logits
+ } else {
+ let start_at = all_tokens.len().saturating_sub(args.repeat_last_n);
+ candle_transformers::utils::apply_repeat_penalty(
+ &logits,
+ args.repeat_penalty,
+ &all_tokens[start_at..],
+ )?
+ };
+ next_token = logits_processor.sample(&logits)?;
+ all_tokens.push(next_token);
+ if let Some(t) = tos.next_token(next_token)? {
+ print!("{t}");
+ std::io::stdout().flush()?;
+ }
+ sampled += 1;
+ if next_token == eos_token {
+ break;
+ };
+ }
+ if let Some(rest) = tos.decode_rest().map_err(candle::Error::msg)? {
+ print!("{rest}");
+ }
+ std::io::stdout().flush()?;
+ let dt = start_post_prompt.elapsed();
+ println!(
+ "\n\n{:4} prompt tokens processed: {:.2} token/s",
+ tokens.len(),
+ tokens.len() as f64 / prompt_dt.as_secs_f64(),
+ );
+ println!(
+ "{sampled:4} tokens generated: {:.2} token/s",
+ sampled as f64 / dt.as_secs_f64(),
+ );
+ Ok(())
+}