diff options
author | Zhuo Jinggang <jg.zhuo@outlook.com> | 2024-07-12 16:00:03 +0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-07-12 10:00:03 +0200 |
commit | c63048d3748649c6f13148eb01e6d812d897a0d2 (patch) | |
tree | 275f50476521bf47bb89530dd822a45ae776e6d3 /candle-examples | |
parent | a226a9736baee550b01de53cb3e416d3d94e69d3 (diff) | |
download | candle-c63048d3748649c6f13148eb01e6d812d897a0d2.tar.gz candle-c63048d3748649c6f13148eb01e6d812d897a0d2.tar.bz2 candle-c63048d3748649c6f13148eb01e6d812d897a0d2.zip |
add quantized qwen2 (#2329)
* add quantized version of qwen2 and corresponding example for qwen2-instruct
* fix quantized qwen2 clippy error
Diffstat (limited to 'candle-examples')
-rw-r--r-- | candle-examples/examples/quantized-qwen2-instruct/README.md | 11 | ||||
-rw-r--r-- | candle-examples/examples/quantized-qwen2-instruct/main.rs | 306 |
2 files changed, 317 insertions, 0 deletions
diff --git a/candle-examples/examples/quantized-qwen2-instruct/README.md b/candle-examples/examples/quantized-qwen2-instruct/README.md new file mode 100644 index 00000000..8129b3fc --- /dev/null +++ b/candle-examples/examples/quantized-qwen2-instruct/README.md @@ -0,0 +1,11 @@ +# candle-quantized-qwen2-instruct + +[Qwen2]((https://qwenlm.github.io/blog/qwen2/)) is an upgraded version of Qwen1.5, released by Alibaba Cloud. + +## Running the example + +```bash +cargo run --example quantized-qwen2-instruct --release -- --prompt "Write a function to count prime numbers up to N." +``` + +0.5b, 1.5b, 7b and 72b models are available via `--model` argument. diff --git a/candle-examples/examples/quantized-qwen2-instruct/main.rs b/candle-examples/examples/quantized-qwen2-instruct/main.rs new file mode 100644 index 00000000..1bd230e0 --- /dev/null +++ b/candle-examples/examples/quantized-qwen2-instruct/main.rs @@ -0,0 +1,306 @@ +#[cfg(feature = "mkl")] +extern crate intel_mkl_src; + +#[cfg(feature = "accelerate")] +extern crate accelerate_src; + +use clap::{Parser, ValueEnum}; +use std::io::Write; +use tokenizers::Tokenizer; + +use candle::quantized::gguf_file; +use candle::Tensor; +use candle_transformers::generation::{LogitsProcessor, Sampling}; + +use candle_examples::token_output_stream::TokenOutputStream; +use candle_transformers::models::quantized_qwen2::ModelWeights as Qwen2; + +const DEFAULT_PROMPT: &str = "Write a function to count prime numbers up to N. "; + +#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)] +enum Which { + #[value(name = "0.5b")] + W2_0_5b, + #[value(name = "1.5b")] + W2_1_5b, + #[value(name = "7b")] + W2_7b, + #[value(name = "72b")] + W2_72b, +} + +#[derive(Parser, Debug)] +#[command(author, version, about, long_about = None)] +struct Args { + /// GGUF file to load, typically a .gguf file generated by the quantize command from llama.cpp + #[arg(long)] + model: Option<String>, + + /// The initial prompt, use 'interactive' for entering multiple prompts in an interactive way + /// and 'chat' for an interactive model where history of previous prompts and generated tokens + /// is preserved. + #[arg(long)] + prompt: Option<String>, + + /// The length of the sample to generate (in tokens). + #[arg(short = 'n', long, default_value_t = 1000)] + sample_len: usize, + + /// The tokenizer config in json format. + #[arg(long)] + tokenizer: Option<String>, + + /// The temperature used to generate samples, use 0 for greedy sampling. + #[arg(long, default_value_t = 0.8)] + temperature: f64, + + /// Nucleus sampling probability cutoff. + #[arg(long)] + top_p: Option<f64>, + + /// Only sample among the top K samples. + #[arg(long)] + top_k: Option<usize>, + + /// The seed to use when generating random samples. + #[arg(long, default_value_t = 299792458)] + seed: u64, + + /// Enable tracing (generates a trace-timestamp.json file). + #[arg(long)] + tracing: bool, + + /// Process prompt elements separately. + #[arg(long)] + split_prompt: bool, + + /// Run on CPU rather than GPU even if a GPU is available. + #[arg(long)] + cpu: bool, + + /// Penalty to be applied for repeating tokens, 1. means no penalty. + #[arg(long, default_value_t = 1.1)] + repeat_penalty: f32, + + /// The context size to consider for the repeat penalty. + #[arg(long, default_value_t = 64)] + repeat_last_n: usize, + + /// The model size to use. + #[arg(long, default_value = "0.5b")] + which: Which, +} + +impl Args { + fn tokenizer(&self) -> anyhow::Result<Tokenizer> { + let tokenizer_path = match &self.tokenizer { + Some(config) => std::path::PathBuf::from(config), + None => { + let api = hf_hub::api::sync::Api::new()?; + let repo = match self.which { + Which::W2_0_5b => "Qwen/Qwen2-0.5B-Instruct", + Which::W2_1_5b => "Qwen/Qwen2-1.5B-Instruct", + Which::W2_7b => "Qwen/Qwen2-7B-Instruct", + Which::W2_72b => "Qwen/Qwen2-72B-Instruct", + }; + let api = api.model(repo.to_string()); + api.get("tokenizer.json")? + } + }; + Tokenizer::from_file(tokenizer_path).map_err(anyhow::Error::msg) + } + + fn model(&self) -> anyhow::Result<std::path::PathBuf> { + let model_path = match &self.model { + Some(config) => std::path::PathBuf::from(config), + None => { + let (repo, filename, revision) = match self.which { + Which::W2_0_5b => ( + "Qwen/Qwen2-0.5B-Instruct-GGUF", + "qwen2-0_5b-instruct-q4_0.gguf", + "main", + ), + Which::W2_1_5b => ( + "Qwen/Qwen2-1.5B-Instruct-GGUF", + "qwen2-1_5b-instruct-q4_0.gguf", + "main", + ), + Which::W2_7b => ( + "Qwen/Qwen2-7B-Instruct-GGUF", + "qwen2-7b-instruct-q4_0.gguf", + "main", + ), + Which::W2_72b => ( + "Qwen/Qwen2-72B-Instruct-GGUF", + "qwen2-72b-instruct-q4_0.gguf", + "main", + ), + }; + let api = hf_hub::api::sync::Api::new()?; + api.repo(hf_hub::Repo::with_revision( + repo.to_string(), + hf_hub::RepoType::Model, + revision.to_string(), + )) + .get(filename)? + } + }; + Ok(model_path) + } +} + +fn format_size(size_in_bytes: usize) -> String { + if size_in_bytes < 1_000 { + format!("{}B", size_in_bytes) + } else if size_in_bytes < 1_000_000 { + format!("{:.2}KB", size_in_bytes as f64 / 1e3) + } else if size_in_bytes < 1_000_000_000 { + format!("{:.2}MB", size_in_bytes as f64 / 1e6) + } else { + format!("{:.2}GB", size_in_bytes as f64 / 1e9) + } +} + +fn main() -> anyhow::Result<()> { + use tracing_chrome::ChromeLayerBuilder; + use tracing_subscriber::prelude::*; + + let args = Args::parse(); + let _guard = if args.tracing { + let (chrome_layer, guard) = ChromeLayerBuilder::new().build(); + tracing_subscriber::registry().with(chrome_layer).init(); + Some(guard) + } else { + None + }; + + println!( + "avx: {}, neon: {}, simd128: {}, f16c: {}", + candle::utils::with_avx(), + candle::utils::with_neon(), + candle::utils::with_simd128(), + candle::utils::with_f16c() + ); + println!( + "temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}", + args.temperature, args.repeat_penalty, args.repeat_last_n + ); + + let model_path = args.model()?; + let mut file = std::fs::File::open(&model_path)?; + let start = std::time::Instant::now(); + let device = candle_examples::device(args.cpu)?; + + let mut model = { + let model = gguf_file::Content::read(&mut file).map_err(|e| e.with_path(model_path))?; + let mut total_size_in_bytes = 0; + for (_, tensor) in model.tensor_infos.iter() { + let elem_count = tensor.shape.elem_count(); + total_size_in_bytes += + elem_count * tensor.ggml_dtype.type_size() / tensor.ggml_dtype.block_size(); + } + println!( + "loaded {:?} tensors ({}) in {:.2}s", + model.tensor_infos.len(), + &format_size(total_size_in_bytes), + start.elapsed().as_secs_f32(), + ); + Qwen2::from_gguf(model, &mut file, &device)? + }; + println!("model built"); + + let tokenizer = args.tokenizer()?; + let mut tos = TokenOutputStream::new(tokenizer); + let prompt_str = args.prompt.unwrap_or_else(|| DEFAULT_PROMPT.to_string()); + let prompt_str = format!( + "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n", + prompt_str + ); + print!("formatted instruct prompt: {}", &prompt_str); + let tokens = tos + .tokenizer() + .encode(prompt_str, true) + .map_err(anyhow::Error::msg)?; + let tokens = tokens.get_ids(); + let to_sample = args.sample_len.saturating_sub(1); + let mut all_tokens = vec![]; + let mut logits_processor = { + let temperature = args.temperature; + let sampling = if temperature <= 0. { + Sampling::ArgMax + } else { + match (args.top_k, args.top_p) { + (None, None) => Sampling::All { temperature }, + (Some(k), None) => Sampling::TopK { k, temperature }, + (None, Some(p)) => Sampling::TopP { p, temperature }, + (Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature }, + } + }; + LogitsProcessor::from_sampling(args.seed, sampling) + }; + let start_prompt_processing = std::time::Instant::now(); + let mut next_token = if !args.split_prompt { + let input = Tensor::new(tokens, &device)?.unsqueeze(0)?; + let logits = model.forward(&input, 0)?; + let logits = logits.squeeze(0)?; + logits_processor.sample(&logits)? + } else { + let mut next_token = 0; + for (pos, token) in tokens.iter().enumerate() { + let input = Tensor::new(&[*token], &device)?.unsqueeze(0)?; + let logits = model.forward(&input, pos)?; + let logits = logits.squeeze(0)?; + next_token = logits_processor.sample(&logits)? + } + next_token + }; + let prompt_dt = start_prompt_processing.elapsed(); + all_tokens.push(next_token); + if let Some(t) = tos.next_token(next_token)? { + print!("{t}"); + std::io::stdout().flush()?; + } + let eos_token = *tos.tokenizer().get_vocab(true).get("<|im_end|>").unwrap(); + let start_post_prompt = std::time::Instant::now(); + let mut sampled = 0; + for index in 0..to_sample { + let input = Tensor::new(&[next_token], &device)?.unsqueeze(0)?; + let logits = model.forward(&input, tokens.len() + index)?; + let logits = logits.squeeze(0)?; + let logits = if args.repeat_penalty == 1. { + logits + } else { + let start_at = all_tokens.len().saturating_sub(args.repeat_last_n); + candle_transformers::utils::apply_repeat_penalty( + &logits, + args.repeat_penalty, + &all_tokens[start_at..], + )? + }; + next_token = logits_processor.sample(&logits)?; + all_tokens.push(next_token); + if let Some(t) = tos.next_token(next_token)? { + print!("{t}"); + std::io::stdout().flush()?; + } + sampled += 1; + if next_token == eos_token { + break; + }; + } + if let Some(rest) = tos.decode_rest().map_err(candle::Error::msg)? { + print!("{rest}"); + } + std::io::stdout().flush()?; + let dt = start_post_prompt.elapsed(); + println!( + "\n\n{:4} prompt tokens processed: {:.2} token/s", + tokens.len(), + tokens.len() as f64 / prompt_dt.as_secs_f64(), + ); + println!( + "{sampled:4} tokens generated: {:.2} token/s", + sampled as f64 / dt.as_secs_f64(), + ); + Ok(()) +} |