diff options
author | Nicolas Patry <patry.nicolas@protonmail.com> | 2024-01-17 10:27:58 +0100
---|---|---
committer | GitHub <noreply@github.com> | 2024-01-17 10:27:58 +0100
commit | 403680f17ddc086295fbaee316cbed22d97a519b (patch)
tree | 80dcffe6e929640e7f0ebfff3ba90410fd58992e /candle-examples/examples
parent | 5270224f407502b82fe90bc2622894ce3871b002 (diff)
download | candle-403680f17ddc086295fbaee316cbed22d97a519b.tar.gz, candle-403680f17ddc086295fbaee316cbed22d97a519b.tar.bz2, candle-403680f17ddc086295fbaee316cbed22d97a519b.zip
Quantized GGUF style (#1523)
* Metal quantized modifications proposal.
- Add a device param wherever needed.
- Create a new QMetal storage type that implements QuantizedType.
- Update everywhere needed.
Fix Python.
Fixing examples.
Fix: fmt + clippy + stub.
Moving everything around.
Only missing the actual implems.
Fixing everything + adding dequantized kernels.
More work.
Fixing matmul.
Fmt + Clippy
Some clippy fixes.
Working state.
Q2K Metal -> Bugged (also present in GGML).
Q4K CPU -> Bugged (present previously, new test catches it).
Q5K CPU -> Bugged (present previously).
Q8_1 Both -> Never really implemented, it seems.
Q8K Metal -> Never implemented in Metal.
Fixing Q2K bug (present in GGML).
* Cleanup.
* Fix the rebase.
* Removing the fences speeds everything up and *is* correct this time...
* Cleanup the fence.
* After rebase.
* Bad code removal.
* Rebase after phi2 merge + fix replit default to CPU.
* Making the CI happy.
* More happy tests.
---------
Co-authored-by: Nicolas Patry <nicolas@Nicolass-MacBook-Pro.local>
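The recurring change across the example diffs below is that the quantized GGUF loaders now take the target device instead of assuming `Device::Cpu`. A minimal sketch of the new call pattern, assuming a hypothetical `model.gguf` path and the `candle_examples::device` helper used throughout the examples:

```rust
use candle_transformers::quantized_var_builder::VarBuilder;

fn load_quantized(cpu: bool) -> anyhow::Result<VarBuilder> {
    // Picks CPU, CUDA or Metal depending on build features and the --cpu flag,
    // exactly as the updated examples do.
    let device = candle_examples::device(cpu)?;
    // `from_gguf` now needs the device so quantized tensors can be placed on
    // Metal (or stay on CPU) at load time. "model.gguf" is a placeholder path.
    let vb = VarBuilder::from_gguf("model.gguf", &device)?;
    Ok(vb)
}
```

The same two steps, pick the device once and pass it to `from_gguf`, are what each hunk below applies to its example.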
Diffstat (limited to 'candle-examples/examples')
-rw-r--r-- | candle-examples/examples/blip/main.rs | 4
-rw-r--r-- | candle-examples/examples/llama2-c/main.rs | 8
-rw-r--r-- | candle-examples/examples/mistral/main.rs | 7
-rw-r--r-- | candle-examples/examples/phi/main.rs | 16
-rw-r--r-- | candle-examples/examples/quantized-t5/main.rs | 3
-rw-r--r-- | candle-examples/examples/quantized/main.rs | 16
-rw-r--r-- | candle-examples/examples/replit-code/main.rs | 13
-rw-r--r-- | candle-examples/examples/stable-lm/main.rs | 5
-rw-r--r-- | candle-examples/examples/whisper/main.rs | 6
9 files changed, 43 insertions, 35 deletions
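Besides the `from_gguf` signature change, the llama2-c hunks below also dequantize tensors onto the selected device rather than onto a hard-coded CPU. A small sketch of that pattern, assuming a quantized `VarBuilder` as above; the shape and the `rot.freq_cis_real` name come from the llama2-c example, while the helper itself is illustrative:

```rust
use candle::{Device, Tensor};
use candle_transformers::quantized_var_builder::VarBuilder;

// Fetch a quantized tensor by shape and name, then dequantize it onto the
// target device, mirroring the freq_cis handling in llama2-c/main.rs.
fn load_freq_cis_real(
    vb: &VarBuilder,
    seq_len: usize,
    half_head_size: usize,
    device: &Device,
) -> candle::Result<Tensor> {
    vb.get((seq_len, half_head_size), "rot.freq_cis_real")?
        .dequantize(device)
}
```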
diff --git a/candle-examples/examples/blip/main.rs b/candle-examples/examples/blip/main.rs
index a1051a8e..15e36476 100644
--- a/candle-examples/examples/blip/main.rs
+++ b/candle-examples/examples/blip/main.rs
@@ -106,17 +106,17 @@ pub fn main() -> anyhow::Result<()> {
     let config = blip::Config::image_captioning_large();
+    let device = candle_examples::device(args.cpu)?;
     let (image_embeds, device, mut model) = if args.quantized {
         let device = Device::Cpu;
         let image = load_image(args.image)?.to_device(&device)?;
         println!("loaded image {image:?}");
-        let vb = quantized_blip::VarBuilder::from_gguf(model_file)?;
+        let vb = quantized_blip::VarBuilder::from_gguf(model_file, &device)?;
         let model = quantized_blip::BlipForConditionalGeneration::new(&config, vb)?;
         let image_embeds = image.unsqueeze(0)?.apply(model.vision_model())?;
         (image_embeds, device, Model::Q(model))
     } else {
-        let device = candle_examples::device(args.cpu)?;
         let image = load_image(args.image)?.to_device(&device)?;
         println!("loaded image {image:?}");
diff --git a/candle-examples/examples/llama2-c/main.rs b/candle-examples/examples/llama2-c/main.rs
index 0ceb27af..9d42dcc8 100644
--- a/candle-examples/examples/llama2-c/main.rs
+++ b/candle-examples/examples/llama2-c/main.rs
@@ -262,7 +262,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
         .extension()
         .map_or(false, |v| v == "safetensors");
     let (model, config) = if is_gguf {
-        let vb = qmodel::VarBuilder::from_gguf(config_path)?;
+        let vb = qmodel::VarBuilder::from_gguf(config_path, &device)?;
         let (_vocab_size, dim) = vb
             .get_no_shape("model.embed_tokens.weight")?
             .shape()
@@ -279,13 +279,13 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
                 (config.seq_len, config.head_size() / 2),
                 "rot.freq_cis_real",
             )?
-            .dequantize(&candle::Device::Cpu)?;
+            .dequantize(&device)?;
         let freq_cis_imag = vb
             .get(
                 (config.seq_len, config.head_size() / 2),
                 "rot.freq_cis_imag",
             )?
-            .dequantize(&candle::Device::Cpu)?;
+            .dequantize(&device)?;
         let fake_vb = candle_nn::VarBuilder::from_tensors(
             [
@@ -295,7 +295,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
             .into_iter()
             .collect(),
             candle::DType::F32,
-            &candle::Device::Cpu,
+            &device,
         );
         let cache = model::Cache::new(true, &config, fake_vb)?;
         let model = Model::QLlama(QLlama::load(vb, &cache, config.clone())?);
diff --git a/candle-examples/examples/mistral/main.rs b/candle-examples/examples/mistral/main.rs
index 5ed5e5cb..bad86098 100644
--- a/candle-examples/examples/mistral/main.rs
+++ b/candle-examples/examples/mistral/main.rs
@@ -244,13 +244,14 @@ fn main() -> Result<()> {
     let start = std::time::Instant::now();
     let config = Config::config_7b_v0_1(args.use_flash_attn);
+    let device = candle_examples::device(args.cpu)?;
     let (model, device) = if args.quantized {
         let filename = &filenames[0];
-        let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename)?;
+        let vb =
+            candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename, &device)?;
         let model = QMistral::new(&config, vb)?;
-        (Model::Quantized(model), Device::Cpu)
+        (Model::Quantized(model), device)
     } else {
-        let device = candle_examples::device(args.cpu)?;
         let dtype = if device.is_cuda() {
             DType::BF16
         } else {
diff --git a/candle-examples/examples/phi/main.rs b/candle-examples/examples/phi/main.rs
index 69eed84f..39f4fd58 100644
--- a/candle-examples/examples/phi/main.rs
+++ b/candle-examples/examples/phi/main.rs
@@ -307,18 +307,21 @@ fn main() -> Result<()> {
         WhichModel::PuffinPhiV2 => Config::puffin_phi_v2(),
         WhichModel::PhiHermes => Config::phi_hermes_1_3b(),
     };
-    let (model, device) = if args.quantized {
-        let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filenames[0])?;
+    let device = candle_examples::device(args.cpu)?;
+    let model = if args.quantized {
         let config = config();
+        let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(
+            &filenames[0],
+            &device,
+        )?;
         let model = match args.model {
             WhichModel::V2 | WhichModel::V2Old => QMixFormer::new_v2(&config, vb)?,
             _ => QMixFormer::new(&config, vb)?,
         };
-        (Model::Quantized(model), Device::Cpu)
+        Model::Quantized(model)
     } else {
-        let device = candle_examples::device(args.cpu)?;
         let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
-        let model = match args.model {
+        match args.model {
             WhichModel::V1 | WhichModel::V1_5 | WhichModel::V2 => {
                 let config_filename = repo.get("config.json")?;
                 let config = std::fs::read_to_string(config_filename)?;
@@ -334,8 +337,7 @@ fn main() -> Result<()> {
                 let config = config();
                 Model::MixFormer(MixFormer::new(&config, vb)?)
             }
-        };
-        (model, device)
+        }
     };
     println!("loaded the model in {:?}", start.elapsed());
diff --git a/candle-examples/examples/quantized-t5/main.rs b/candle-examples/examples/quantized-t5/main.rs
index 0ea2e0bd..ed3f1030 100644
--- a/candle-examples/examples/quantized-t5/main.rs
+++ b/candle-examples/examples/quantized-t5/main.rs
@@ -132,7 +132,8 @@ impl T5ModelBuilder {
     }
     pub fn build_model(&self) -> Result<t5::T5ForConditionalGeneration> {
-        let vb = t5::VarBuilder::from_gguf(&self.weights_filename)?;
+        let device = Device::Cpu;
+        let vb = t5::VarBuilder::from_gguf(&self.weights_filename, &device)?;
         Ok(t5::T5ForConditionalGeneration::load(vb, &self.config)?)
     }
diff --git a/candle-examples/examples/quantized/main.rs b/candle-examples/examples/quantized/main.rs
index bfc6de53..34c44233 100644
--- a/candle-examples/examples/quantized/main.rs
+++ b/candle-examples/examples/quantized/main.rs
@@ -9,7 +9,7 @@ use std::io::Write;
 use tokenizers::Tokenizer;
 use candle::quantized::{ggml_file, gguf_file};
-use candle::{Device, Tensor};
+use candle::Tensor;
 use candle_transformers::generation::LogitsProcessor;
 use candle_examples::token_output_stream::TokenOutputStream;
@@ -361,6 +361,7 @@ fn main() -> anyhow::Result<()> {
     let model_path = args.model()?;
     let mut file = std::fs::File::open(&model_path)?;
     let start = std::time::Instant::now();
+    let device = candle_examples::device(false)?;
     let mut model = match model_path.extension().and_then(|v| v.to_str()) {
         Some("gguf") => {
@@ -369,7 +370,7 @@
             for (_, tensor) in model.tensor_infos.iter() {
                 let elem_count = tensor.shape.elem_count();
                 total_size_in_bytes +=
-                    elem_count * tensor.ggml_dtype.type_size() / tensor.ggml_dtype.blck_size();
+                    elem_count * tensor.ggml_dtype.type_size() / tensor.ggml_dtype.block_size();
             }
             println!(
                 "loaded {:?} tensors ({}) in {:.2}s",
@@ -377,15 +378,16 @@
                 &format_size(total_size_in_bytes),
                 start.elapsed().as_secs_f32(),
             );
-            ModelWeights::from_gguf(model, &mut file)?
+            ModelWeights::from_gguf(model, &mut file, &device)?
         }
         Some("ggml" | "bin") | Some(_) | None => {
-            let model = ggml_file::Content::read(&mut file).map_err(|e| e.with_path(model_path))?;
+            let model = ggml_file::Content::read(&mut file, &device)
+                .map_err(|e| e.with_path(model_path))?;
             let mut total_size_in_bytes = 0;
             for (_, tensor) in model.tensors.iter() {
                 let elem_count = tensor.shape().elem_count();
                 total_size_in_bytes +=
-                    elem_count * tensor.dtype().type_size() / tensor.dtype().blck_size();
+                    elem_count * tensor.dtype().type_size() / tensor.dtype().block_size();
             }
             println!(
                 "loaded {:?} tensors ({}) in {:.2}s",
@@ -486,7 +488,7 @@
     let start_prompt_processing = std::time::Instant::now();
     let mut next_token = {
-        let input = Tensor::new(prompt_tokens.as_slice(), &Device::Cpu)?.unsqueeze(0)?;
+        let input = Tensor::new(prompt_tokens.as_slice(), &device)?.unsqueeze(0)?;
         let logits = model.forward(&input, 0)?;
         let logits = logits.squeeze(0)?;
         logits_processor.sample(&logits)?
@@ -507,7 +509,7 @@
     let start_post_prompt = std::time::Instant::now();
     let mut sampled = 0;
     for index in 0..to_sample {
-        let input = Tensor::new(&[next_token], &Device::Cpu)?.unsqueeze(0)?;
+        let input = Tensor::new(&[next_token], &device)?.unsqueeze(0)?;
         let logits = model.forward(&input, prompt_tokens.len() + index)?;
         let logits = logits.squeeze(0)?;
         let logits = if args.repeat_penalty == 1. {
diff --git a/candle-examples/examples/replit-code/main.rs b/candle-examples/examples/replit-code/main.rs
index 0f72b862..b7f767b9 100644
--- a/candle-examples/examples/replit-code/main.rs
+++ b/candle-examples/examples/replit-code/main.rs
@@ -236,16 +236,15 @@ fn main() -> Result<()> {
     let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
     let start = std::time::Instant::now();
+    let device = candle_examples::device(args.cpu)?;
     let config = Config::replit_code_v1_5_3b();
-    let (model, device) = if args.quantized {
-        let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filename)?;
-        let model = Model::Q(Q::new(&config, vb.pp("transformer"))?);
-        (model, Device::Cpu)
+    let model = if args.quantized {
+        let vb =
+            candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filename, &device)?;
+        Model::Q(Q::new(&config, vb.pp("transformer"))?)
     } else {
-        let device = candle_examples::device(args.cpu)?;
         let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[filename], DType::F32, &device)? };
-        let model = Model::M(M::new(&config, vb.pp("transformer"))?);
-        (model, device)
+        Model::M(M::new(&config, vb.pp("transformer"))?)
     };
     println!("loaded the model in {:?}", start.elapsed());
diff --git a/candle-examples/examples/stable-lm/main.rs b/candle-examples/examples/stable-lm/main.rs
index 0535aa70..ccd924a4 100644
--- a/candle-examples/examples/stable-lm/main.rs
+++ b/candle-examples/examples/stable-lm/main.rs
@@ -234,13 +234,14 @@ fn main() -> Result<()> {
     let start = std::time::Instant::now();
     let config = Config::stablelm_3b_4e1t(args.use_flash_attn);
+    let device = candle_examples::device(args.cpu)?;
     let (model, device) = if args.quantized {
         let filename = &filenames[0];
-        let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename)?;
+        let vb =
+            candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename, &device)?;
         let model = QStableLM::new(&config, vb)?;
         (Model::Quantized(model), Device::Cpu)
     } else {
-        let device = candle_examples::device(args.cpu)?;
         let dtype = if device.is_cuda() {
             DType::BF16
         } else {
diff --git a/candle-examples/examples/whisper/main.rs b/candle-examples/examples/whisper/main.rs
index 5be81f2d..6ea34613 100644
--- a/candle-examples/examples/whisper/main.rs
+++ b/candle-examples/examples/whisper/main.rs
@@ -557,8 +557,10 @@ fn main() -> Result<()> {
     println!("loaded mel: {:?}", mel.dims());
     let mut model = if args.quantized {
-        let vb =
-            candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&weights_filename)?;
+        let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(
+            &weights_filename,
+            &device,
+        )?;
         Model::Quantized(m::quantized_model::Whisper::load(&vb, config)?)
     } else {
         let vb =
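One further detail in the quantized/main.rs hunks above is the rename of `blck_size` to `block_size`: a quantized tensor stores `type_size()` bytes per block of `block_size()` elements, so its byte size is `elem_count * type_size / block_size`. A short sketch of that accounting over a parsed GGUF file, assuming the public `tensor_infos` map that the diff itself iterates over:

```rust
use candle::quantized::gguf_file;

// Sum the on-disk size of all tensors in a GGUF file, as the updated quantized
// example does: each block of `block_size()` elements occupies `type_size()` bytes.
fn total_size_in_bytes(content: &gguf_file::Content) -> usize {
    content
        .tensor_infos
        .values()
        .map(|info| {
            info.shape.elem_count() * info.ggml_dtype.type_size() / info.ggml_dtype.block_size()
        })
        .sum()
}
```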