author    Nicolas Patry <patry.nicolas@protonmail.com>  2024-01-17 10:27:58 +0100
committer GitHub <noreply@github.com>  2024-01-17 10:27:58 +0100
commit    403680f17ddc086295fbaee316cbed22d97a519b (patch)
tree      80dcffe6e929640e7f0ebfff3ba90410fd58992e /candle-examples/examples/llama2-c
parent    5270224f407502b82fe90bc2622894ce3871b002 (diff)
Quantized GGUF style (#1523)
* Metal quantized modifications proposal.
- Add a device param, wherever needed.
- Create new QMetal storage thing that implements QuantizedType.
- Update everywhere needed.
Fix Python.
Fixing examples.
Fix: fmt + clippy + stub.
Moving everything around.
Only missing the actual implementations.
Fixing everything + adding dequantized kernels.
More work.
Fixing matmul.
Fmt + Clippy
Some clippy fixes.
Working state.
Q2K Metal -> Bugged (also present in GGML).
Q4K CPU -> Bugged (present previously, new test catches it).
Q5K CPU -> Bugged (present previously).
Q8_1 Both -> Never really implemented, it seems.
Q8K Metal -> Never implemented in Metal.
Fixing Q2K bug (present in ggml).
* Cleanup.
* Fix the rebase.
* Removing the fences speeds everything up and *is* correct this time...
* Cleanup the fence.
* After rebase.
* Bad code removal.
* Rebase after phi2 merge + fix replit default to CPU.
* Making the CI happy.
* More happy tests.
---------
Co-authored-by: Nicolas Patry <nicolas@Nicolass-MacBook-Pro.local>
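The "new QMetal storage thing that implements QuantizedType" from the message above can be sketched as a self-contained toy: each backend gets its own quantized-storage type behind a shared trait, and dequantization is dispatched through that trait. The names (`QuantizedType`, `QCpuStorage`, `Device`) echo the commit message but this is an illustration of the pattern, not candle's real types.

```rust
// Toy sketch of per-backend quantized storage behind one trait.
// Real candle uses block-quantized formats (Q4K, Q8K, ...); here a
// single i8 buffer with one scale stands in for all of them.

#[derive(Debug, Clone, Copy, PartialEq)]
enum Device {
    Cpu,
    Metal,
}

trait QuantizedType {
    /// Which device this storage lives on.
    fn device(&self) -> Device;
    /// Dequantize back to f32 values.
    fn dequantize(&self) -> Vec<f32>;
}

/// CPU-side quantized storage: i8 values plus a single scale factor.
struct QCpuStorage {
    data: Vec<i8>,
    scale: f32,
}

impl QuantizedType for QCpuStorage {
    fn device(&self) -> Device {
        Device::Cpu
    }
    fn dequantize(&self) -> Vec<f32> {
        // value * scale recovers an approximation of the original f32.
        self.data.iter().map(|&v| v as f32 * self.scale).collect()
    }
}

fn main() {
    let q = QCpuStorage { data: vec![1, -2, 4], scale: 0.5 };
    assert_eq!(q.device(), Device::Cpu);
    assert_eq!(q.dequantize(), vec![0.5, -1.0, 2.0]);
}
```

A `QMetalStorage` would implement the same trait with Metal buffers and kernels; callers only see `QuantizedType`, which is why the rest of the codebase "only needed updating wherever a device param was missing".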
Diffstat (limited to 'candle-examples/examples/llama2-c')
-rw-r--r-- | candle-examples/examples/llama2-c/main.rs | 8 |
1 file changed, 4 insertions(+), 4 deletions(-)
diff --git a/candle-examples/examples/llama2-c/main.rs b/candle-examples/examples/llama2-c/main.rs
index 0ceb27af..9d42dcc8 100644
--- a/candle-examples/examples/llama2-c/main.rs
+++ b/candle-examples/examples/llama2-c/main.rs
@@ -262,7 +262,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
         .extension()
         .map_or(false, |v| v == "safetensors");
     let (model, config) = if is_gguf {
-        let vb = qmodel::VarBuilder::from_gguf(config_path)?;
+        let vb = qmodel::VarBuilder::from_gguf(config_path, &device)?;
         let (_vocab_size, dim) = vb
             .get_no_shape("model.embed_tokens.weight")?
             .shape()
@@ -279,13 +279,13 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
                 (config.seq_len, config.head_size() / 2),
                 "rot.freq_cis_real",
             )?
-            .dequantize(&candle::Device::Cpu)?;
+            .dequantize(&device)?;
         let freq_cis_imag = vb
             .get(
                 (config.seq_len, config.head_size() / 2),
                 "rot.freq_cis_imag",
             )?
-            .dequantize(&candle::Device::Cpu)?;
+            .dequantize(&device)?;
         let fake_vb = candle_nn::VarBuilder::from_tensors(
             [
@@ -295,7 +295,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
             .into_iter()
             .collect(),
             candle::DType::F32,
-            &candle::Device::Cpu,
+            &device,
         );
         let cache = model::Cache::new(true, &config, fake_vb)?;
         let model = Model::QLlama(QLlama::load(vb, &cache, config.clone())?);
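The change in the diff follows one pattern throughout: pick the device once and thread it into loading and dequantization instead of hardcoding `Device::Cpu`. A self-contained toy of that call-site shape (the `VarBuilder` here is a stub mirroring the new `from_gguf(path, &device)` signature, not candle's real type):

```rust
// Toy sketch of threading a device through model loading, as the
// diff does for the real qmodel::VarBuilder.

#[derive(Debug, Clone, Copy, PartialEq)]
enum Device {
    Cpu,
    Metal,
}

/// Stub builder that remembers which device it was created for.
struct VarBuilder {
    device: Device,
}

impl VarBuilder {
    // Mirrors the new two-argument signature from the diff.
    fn from_gguf(_path: &str, device: &Device) -> Self {
        VarBuilder { device: *device }
    }

    /// Dequantize onto the builder's device rather than a hardcoded CPU.
    fn dequantize(&self) -> Device {
        self.device
    }
}

fn pick_device(use_metal: bool) -> Device {
    if use_metal { Device::Metal } else { Device::Cpu }
}

fn main() {
    // The device is chosen once, then flows through every call...
    let device = pick_device(true);
    let vb = VarBuilder::from_gguf("model.gguf", &device);
    // ...so dequantized tensors land on the selected backend.
    assert_eq!(vb.dequantize(), Device::Metal);
}
```

The point of the refactor is that no call site below the device selection ever names `Device::Cpu` again, so switching to Metal is a one-line change at the top.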