path: root/candle-examples/examples
author    Nicolas Patry <patry.nicolas@protonmail.com>  2024-01-17 10:27:58 +0100
committer GitHub <noreply@github.com>  2024-01-17 10:27:58 +0100
commit    403680f17ddc086295fbaee316cbed22d97a519b (patch)
tree      80dcffe6e929640e7f0ebfff3ba90410fd58992e /candle-examples/examples
parent    5270224f407502b82fe90bc2622894ce3871b002 (diff)
download  candle-403680f17ddc086295fbaee316cbed22d97a519b.tar.gz
          candle-403680f17ddc086295fbaee316cbed22d97a519b.tar.bz2
          candle-403680f17ddc086295fbaee316cbed22d97a519b.zip
Quantized GGUF style (#1523)
* Metal quantized modifications proposal.

  - Add a device param, wherever needed.
  - Create new QMetal storage thing that implements QuantizedType.
  - Update everywhere needed.

  Fix Python. Fixing examples. Fix: fmt + clippy + stub. Moving everything around.
  Only missing the actual implems. Fixing everything + adding dequantized kernels.
  More work. Fixing matmul. Fmt + Clippy. Some clippy fixes. Working state.

  Q2K Metal -> Bugged (also present in GGML).
  Q4K CPU   -> Bugged (present previously, new test catches it).
  Q5K CPU   -> Bugged (present previously).
  Q8_1 Both -> Never really implemented it seems.
  Q8K Metal -> Never implemented in Metal.

  Fixing Q2K bug (present in ggml).

* Cleanup.

* Fix the rebase.

* Removing the fences speeds everything up and *is* correct this time...

* Cleanup the fence.

* After rebase.

* Bad code removal.

* Rebase after phi2 merge + fix replit default to CPU.

* Making the CI happy.

* More happy tests.

---------

Co-authored-by: Nicolas Patry <nicolas@Nicolass-MacBook-Pro.local>
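The caller-side change is the same across every example below: select the device once via candle_examples::device and pass it to the quantized VarBuilder instead of hard-coding Device::Cpu. A minimal sketch of the new calling convention (the load_quantized helper is illustrative, not part of the examples; the two calls mirror the hunks below):

    use candle_transformers::quantized_var_builder::VarBuilder;

    // Hypothetical helper showing the updated pattern used by the examples.
    fn load_quantized(gguf_path: &std::path::Path, cpu: bool) -> anyhow::Result<VarBuilder> {
        // Device selection is hoisted out of the quantized branch, so GGUF weights
        // can target Metal/CUDA instead of always falling back to the CPU.
        let device = candle_examples::device(cpu)?;
        // `from_gguf` now takes the target device as an extra argument.
        let vb = VarBuilder::from_gguf(gguf_path, &device)?;
        Ok(vb)
    }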
Diffstat (limited to 'candle-examples/examples')
-rw-r--r--  candle-examples/examples/blip/main.rs           4
-rw-r--r--  candle-examples/examples/llama2-c/main.rs       8
-rw-r--r--  candle-examples/examples/mistral/main.rs        7
-rw-r--r--  candle-examples/examples/phi/main.rs           16
-rw-r--r--  candle-examples/examples/quantized-t5/main.rs   3
-rw-r--r--  candle-examples/examples/quantized/main.rs     16
-rw-r--r--  candle-examples/examples/replit-code/main.rs   13
-rw-r--r--  candle-examples/examples/stable-lm/main.rs      5
-rw-r--r--  candle-examples/examples/whisper/main.rs        6
9 files changed, 43 insertions(+), 35 deletions(-)
diff --git a/candle-examples/examples/blip/main.rs b/candle-examples/examples/blip/main.rs
index a1051a8e..15e36476 100644
--- a/candle-examples/examples/blip/main.rs
+++ b/candle-examples/examples/blip/main.rs
@@ -106,17 +106,17 @@ pub fn main() -> anyhow::Result<()> {
let config = blip::Config::image_captioning_large();
+ let device = candle_examples::device(args.cpu)?;
let (image_embeds, device, mut model) = if args.quantized {
let device = Device::Cpu;
let image = load_image(args.image)?.to_device(&device)?;
println!("loaded image {image:?}");
- let vb = quantized_blip::VarBuilder::from_gguf(model_file)?;
+ let vb = quantized_blip::VarBuilder::from_gguf(model_file, &device)?;
let model = quantized_blip::BlipForConditionalGeneration::new(&config, vb)?;
let image_embeds = image.unsqueeze(0)?.apply(model.vision_model())?;
(image_embeds, device, Model::Q(model))
} else {
- let device = candle_examples::device(args.cpu)?;
let image = load_image(args.image)?.to_device(&device)?;
println!("loaded image {image:?}");
diff --git a/candle-examples/examples/llama2-c/main.rs b/candle-examples/examples/llama2-c/main.rs
index 0ceb27af..9d42dcc8 100644
--- a/candle-examples/examples/llama2-c/main.rs
+++ b/candle-examples/examples/llama2-c/main.rs
@@ -262,7 +262,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
.extension()
.map_or(false, |v| v == "safetensors");
let (model, config) = if is_gguf {
- let vb = qmodel::VarBuilder::from_gguf(config_path)?;
+ let vb = qmodel::VarBuilder::from_gguf(config_path, &device)?;
let (_vocab_size, dim) = vb
.get_no_shape("model.embed_tokens.weight")?
.shape()
@@ -279,13 +279,13 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
(config.seq_len, config.head_size() / 2),
"rot.freq_cis_real",
)?
- .dequantize(&candle::Device::Cpu)?;
+ .dequantize(&device)?;
let freq_cis_imag = vb
.get(
(config.seq_len, config.head_size() / 2),
"rot.freq_cis_imag",
)?
- .dequantize(&candle::Device::Cpu)?;
+ .dequantize(&device)?;
let fake_vb = candle_nn::VarBuilder::from_tensors(
[
@@ -295,7 +295,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
.into_iter()
.collect(),
candle::DType::F32,
- &candle::Device::Cpu,
+ &device,
);
let cache = model::Cache::new(true, &config, fake_vb)?;
let model = Model::QLlama(QLlama::load(vb, &cache, config.clone())?);
diff --git a/candle-examples/examples/mistral/main.rs b/candle-examples/examples/mistral/main.rs
index 5ed5e5cb..bad86098 100644
--- a/candle-examples/examples/mistral/main.rs
+++ b/candle-examples/examples/mistral/main.rs
@@ -244,13 +244,14 @@ fn main() -> Result<()> {
let start = std::time::Instant::now();
let config = Config::config_7b_v0_1(args.use_flash_attn);
+ let device = candle_examples::device(args.cpu)?;
let (model, device) = if args.quantized {
let filename = &filenames[0];
- let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename)?;
+ let vb =
+ candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename, &device)?;
let model = QMistral::new(&config, vb)?;
- (Model::Quantized(model), Device::Cpu)
+ (Model::Quantized(model), device)
} else {
- let device = candle_examples::device(args.cpu)?;
let dtype = if device.is_cuda() {
DType::BF16
} else {
diff --git a/candle-examples/examples/phi/main.rs b/candle-examples/examples/phi/main.rs
index 69eed84f..39f4fd58 100644
--- a/candle-examples/examples/phi/main.rs
+++ b/candle-examples/examples/phi/main.rs
@@ -307,18 +307,21 @@ fn main() -> Result<()> {
WhichModel::PuffinPhiV2 => Config::puffin_phi_v2(),
WhichModel::PhiHermes => Config::phi_hermes_1_3b(),
};
- let (model, device) = if args.quantized {
- let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filenames[0])?;
+ let device = candle_examples::device(args.cpu)?;
+ let model = if args.quantized {
let config = config();
+ let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(
+ &filenames[0],
+ &device,
+ )?;
let model = match args.model {
WhichModel::V2 | WhichModel::V2Old => QMixFormer::new_v2(&config, vb)?,
_ => QMixFormer::new(&config, vb)?,
};
- (Model::Quantized(model), Device::Cpu)
+ Model::Quantized(model)
} else {
- let device = candle_examples::device(args.cpu)?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
- let model = match args.model {
+ match args.model {
WhichModel::V1 | WhichModel::V1_5 | WhichModel::V2 => {
let config_filename = repo.get("config.json")?;
let config = std::fs::read_to_string(config_filename)?;
@@ -334,8 +337,7 @@ fn main() -> Result<()> {
let config = config();
Model::MixFormer(MixFormer::new(&config, vb)?)
}
- };
- (model, device)
+ }
};
println!("loaded the model in {:?}", start.elapsed());
diff --git a/candle-examples/examples/quantized-t5/main.rs b/candle-examples/examples/quantized-t5/main.rs
index 0ea2e0bd..ed3f1030 100644
--- a/candle-examples/examples/quantized-t5/main.rs
+++ b/candle-examples/examples/quantized-t5/main.rs
@@ -132,7 +132,8 @@ impl T5ModelBuilder {
}
pub fn build_model(&self) -> Result<t5::T5ForConditionalGeneration> {
- let vb = t5::VarBuilder::from_gguf(&self.weights_filename)?;
+ let device = Device::Cpu;
+ let vb = t5::VarBuilder::from_gguf(&self.weights_filename, &device)?;
Ok(t5::T5ForConditionalGeneration::load(vb, &self.config)?)
}
diff --git a/candle-examples/examples/quantized/main.rs b/candle-examples/examples/quantized/main.rs
index bfc6de53..34c44233 100644
--- a/candle-examples/examples/quantized/main.rs
+++ b/candle-examples/examples/quantized/main.rs
@@ -9,7 +9,7 @@ use std::io::Write;
use tokenizers::Tokenizer;
use candle::quantized::{ggml_file, gguf_file};
-use candle::{Device, Tensor};
+use candle::Tensor;
use candle_transformers::generation::LogitsProcessor;
use candle_examples::token_output_stream::TokenOutputStream;
@@ -361,6 +361,7 @@ fn main() -> anyhow::Result<()> {
let model_path = args.model()?;
let mut file = std::fs::File::open(&model_path)?;
let start = std::time::Instant::now();
+ let device = candle_examples::device(false)?;
let mut model = match model_path.extension().and_then(|v| v.to_str()) {
Some("gguf") => {
@@ -369,7 +370,7 @@ fn main() -> anyhow::Result<()> {
for (_, tensor) in model.tensor_infos.iter() {
let elem_count = tensor.shape.elem_count();
total_size_in_bytes +=
- elem_count * tensor.ggml_dtype.type_size() / tensor.ggml_dtype.blck_size();
+ elem_count * tensor.ggml_dtype.type_size() / tensor.ggml_dtype.block_size();
}
println!(
"loaded {:?} tensors ({}) in {:.2}s",
@@ -377,15 +378,16 @@ fn main() -> anyhow::Result<()> {
&format_size(total_size_in_bytes),
start.elapsed().as_secs_f32(),
);
- ModelWeights::from_gguf(model, &mut file)?
+ ModelWeights::from_gguf(model, &mut file, &device)?
}
Some("ggml" | "bin") | Some(_) | None => {
- let model = ggml_file::Content::read(&mut file).map_err(|e| e.with_path(model_path))?;
+ let model = ggml_file::Content::read(&mut file, &device)
+ .map_err(|e| e.with_path(model_path))?;
let mut total_size_in_bytes = 0;
for (_, tensor) in model.tensors.iter() {
let elem_count = tensor.shape().elem_count();
total_size_in_bytes +=
- elem_count * tensor.dtype().type_size() / tensor.dtype().blck_size();
+ elem_count * tensor.dtype().type_size() / tensor.dtype().block_size();
}
println!(
"loaded {:?} tensors ({}) in {:.2}s",
@@ -486,7 +488,7 @@ fn main() -> anyhow::Result<()> {
let start_prompt_processing = std::time::Instant::now();
let mut next_token = {
- let input = Tensor::new(prompt_tokens.as_slice(), &Device::Cpu)?.unsqueeze(0)?;
+ let input = Tensor::new(prompt_tokens.as_slice(), &device)?.unsqueeze(0)?;
let logits = model.forward(&input, 0)?;
let logits = logits.squeeze(0)?;
logits_processor.sample(&logits)?
@@ -507,7 +509,7 @@ fn main() -> anyhow::Result<()> {
let start_post_prompt = std::time::Instant::now();
let mut sampled = 0;
for index in 0..to_sample {
- let input = Tensor::new(&[next_token], &Device::Cpu)?.unsqueeze(0)?;
+ let input = Tensor::new(&[next_token], &device)?.unsqueeze(0)?;
let logits = model.forward(&input, prompt_tokens.len() + index)?;
let logits = logits.squeeze(0)?;
let logits = if args.repeat_penalty == 1. {
diff --git a/candle-examples/examples/replit-code/main.rs b/candle-examples/examples/replit-code/main.rs
index 0f72b862..b7f767b9 100644
--- a/candle-examples/examples/replit-code/main.rs
+++ b/candle-examples/examples/replit-code/main.rs
@@ -236,16 +236,15 @@ fn main() -> Result<()> {
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let start = std::time::Instant::now();
+ let device = candle_examples::device(args.cpu)?;
let config = Config::replit_code_v1_5_3b();
- let (model, device) = if args.quantized {
- let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filename)?;
- let model = Model::Q(Q::new(&config, vb.pp("transformer"))?);
- (model, Device::Cpu)
+ let model = if args.quantized {
+ let vb =
+ candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filename, &device)?;
+ Model::Q(Q::new(&config, vb.pp("transformer"))?)
} else {
- let device = candle_examples::device(args.cpu)?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[filename], DType::F32, &device)? };
- let model = Model::M(M::new(&config, vb.pp("transformer"))?);
- (model, device)
+ Model::M(M::new(&config, vb.pp("transformer"))?)
};
println!("loaded the model in {:?}", start.elapsed());
diff --git a/candle-examples/examples/stable-lm/main.rs b/candle-examples/examples/stable-lm/main.rs
index 0535aa70..ccd924a4 100644
--- a/candle-examples/examples/stable-lm/main.rs
+++ b/candle-examples/examples/stable-lm/main.rs
@@ -234,13 +234,14 @@ fn main() -> Result<()> {
let start = std::time::Instant::now();
let config = Config::stablelm_3b_4e1t(args.use_flash_attn);
+ let device = candle_examples::device(args.cpu)?;
let (model, device) = if args.quantized {
let filename = &filenames[0];
- let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename)?;
+ let vb =
+ candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename, &device)?;
let model = QStableLM::new(&config, vb)?;
(Model::Quantized(model), Device::Cpu)
} else {
- let device = candle_examples::device(args.cpu)?;
let dtype = if device.is_cuda() {
DType::BF16
} else {
diff --git a/candle-examples/examples/whisper/main.rs b/candle-examples/examples/whisper/main.rs
index 5be81f2d..6ea34613 100644
--- a/candle-examples/examples/whisper/main.rs
+++ b/candle-examples/examples/whisper/main.rs
@@ -557,8 +557,10 @@ fn main() -> Result<()> {
println!("loaded mel: {:?}", mel.dims());
let mut model = if args.quantized {
- let vb =
- candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&weights_filename)?;
+ let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(
+ &weights_filename,
+ &device,
+ )?;
Model::Quantized(m::quantized_model::Whisper::load(&vb, config)?)
} else {
let vb =