-rw-r--r--  candle-core/src/quantized/k_quants.rs    8
-rw-r--r--  candle-core/tests/quantized_tests.rs    50
-rw-r--r--  candle-examples/examples/ggml/main.rs   66
3 files changed, 110 insertions, 14 deletions
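
Commit summary: this change fixes the signedness of the Q8_0 block (`qs` now holds `i8`, matching the signed rounding in `from_float`), corrects the output indexing in `dequantize_row_q6k` so each block writes its own `QK_K`-wide window, adds a regression test covering negative values in the quantized matmul, and extends the ggml example with a per-layer KV cache, Hugging Face Hub defaults for the model and tokenizer paths, and a tokens-per-second report.
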
diff --git a/candle-core/src/quantized/k_quants.rs b/candle-core/src/quantized/k_quants.rs
index f7611897..366eca1e 100644
--- a/candle-core/src/quantized/k_quants.rs
+++ b/candle-core/src/quantized/k_quants.rs
@@ -70,7 +70,7 @@ const _: () = assert!(std::mem::size_of::<BlockQ5_1>() == 24);
#[repr(C)]
pub struct BlockQ8_0 {
d: f16,
- qs: [u8; QK8_0],
+ qs: [i8; QK8_0],
}
const _: () = assert!(std::mem::size_of::<BlockQ8_0>() == 34);
@@ -476,14 +476,14 @@ impl GgmlType for BlockQ6K {
if k % QK_K != 0 {
crate::bail!("dequantize_row_q6k: {k} is not divisible by {QK_K}")
}
- for x in xs.iter() {
+ for (idx_x, x) in xs.iter().enumerate() {
let d = x.d.to_f32();
let ql = &x.ql;
let qh = &x.qh;
let sc = &x.scales;
for n in (0..QK_K).step_by(128) {
let idx = n / 128;
- let ys = &mut ys[n..];
+ let ys = &mut ys[idx_x * QK_K + n..];
let sc = &sc[8 * idx..];
let ql = &ql[64 * idx..];
let qh = &qh[32 * idx..];
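
The indexing fix above matters because `dequantize_row_q6k` emits `QK_K` output values per input block: slicing the output at the within-block offset `n` alone made every block overwrite the first `QK_K` entries of `ys`. A minimal sketch of the pattern, with a toy block width instead of `QK_K = 256` and plain copies instead of the real 6-bit decode:

    // Toy dequantizer: block i must write to ys[i * BLOCK .. (i + 1) * BLOCK].
    const BLOCK: usize = 4;

    fn dequantize(blocks: &[[f32; BLOCK]], ys: &mut [f32]) {
        for (idx_x, block) in blocks.iter().enumerate() {
            // Without the `idx_x * BLOCK` offset, every block would clobber
            // the start of the output buffer -- the bug fixed above.
            let ys = &mut ys[idx_x * BLOCK..];
            for (y, q) in ys.iter_mut().zip(block.iter()) {
                *y = *q;
            }
        }
    }

    fn main() {
        let blocks = [[1.0; BLOCK], [2.0; BLOCK]];
        let mut ys = vec![0.0; 2 * BLOCK];
        dequantize(&blocks, &mut ys);
        // The second block lands in its own window instead of overwriting the first.
        assert_eq!(ys, vec![1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 2.0]);
    }
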
@@ -663,7 +663,7 @@ impl GgmlType for BlockQ8_0 {
let id = if d != 0f32 { 1. / d } else { 0. };
ys.d = f16::from_f32(d);
for (y, &x) in ys.qs.iter_mut().zip(xs.iter()) {
- *y = f32::round(x * id) as u8
+ *y = f32::round(x * id) as i8
}
}
Ok(())
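
Both halves of this fix are about signedness: Q8_0 stores signed 8-bit quants, and in Rust a float-to-integer `as` cast saturates, so casting a negative rounded value to `u8` silently clamps it to 0 and destroys the sign information. A minimal sketch of the failure mode, assuming nothing beyond std:

    fn main() {
        let q = f32::round(-3.4);
        assert_eq!(q as i8, -3); // what the signed Q8_0 layout expects
        assert_eq!(q as u8, 0);  // the old cast: negatives saturate to zero
    }
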
diff --git a/candle-core/tests/quantized_tests.rs b/candle-core/tests/quantized_tests.rs
index babd71a8..9c5168bf 100644
--- a/candle-core/tests/quantized_tests.rs
+++ b/candle-core/tests/quantized_tests.rs
@@ -1,5 +1,7 @@
use candle_core::{quantized, Device, Result, Tensor};
use quantized::{k_quants, GgmlType};
+mod test_utils;
+use test_utils::to_vec2_round;
#[test]
fn quantized_matmul() -> Result<()> {
@@ -46,6 +48,54 @@ fn quantized_matmul() -> Result<()> {
}
#[test]
+fn quantized_matmul_neg() -> Result<()> {
+ let cpu = &Device::Cpu;
+ let (m, k, n) = (3, 64, 4);
+ let lhs = (0..(m * k))
+ .map(|v| v as f32 - (m * k) as f32 / 2.0)
+ .collect::<Vec<_>>();
+ let tensor_lhs = Tensor::from_slice(&lhs, (m, k), cpu)?;
+ let mut dst = vec![42.; 3 * 4];
+ let mut rhs_t = vec![k_quants::BlockQ4_0::zeros(); 8];
+ let rhs = (0..k * n)
+ .map(|v| v as f32 - (k * n) as f32 / 3.0)
+ .collect::<Vec<_>>();
+ let tensor_rhs = Tensor::from_slice(&rhs, (n, k), cpu)?.t()?;
+ k_quants::BlockQ4_0::from_float(&rhs, &mut rhs_t)?;
+ k_quants::matmul((m, k, n), &lhs, &rhs_t, &mut dst)?;
+ assert_eq!(
+ dst,
+ &[
+ 243524.14, -19596.34, -285051.3, -549814.94, 23776.629, 21650.926, 19397.924,
+ 18366.586, -196472.1, 63011.6, 324584.56, 587901.9
+ ]
+ );
+ let mm = tensor_lhs.matmul(&tensor_rhs)?;
+ assert_eq!(
+ to_vec2_round(&mm, 0)?,
+ &[
+ [244064.0, -20128.0, -284320.0, -548512.0],
+ [23563.0, 21515.0, 19467.0, 17419.0],
+ [-196939.0, 63157.0, 323253.0, 583349.0]
+ ]
+ );
+
+ let qtensor = quantized::QTensor::new(rhs_t, (4, 64));
+ let matmul = quantized::QMatMul::from_qtensor(qtensor);
+ let res = matmul.forward(&tensor_lhs)?;
+ assert_eq!(
+ to_vec2_round(&res, 0)?,
+ &[
+ [243524.0, -19596.0, -285051.0, -549815.0],
+ [23777.0, 21651.0, 19398.0, 18367.0],
+ [-196472.0, 63012.0, 324585.0, 587902.0]
+ ]
+ );
+
+ Ok(())
+}
+
+#[test]
fn quantize_q4_0() -> Result<()> {
use k_quants::BlockQ4_0;
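
The three asserts in `quantized_matmul_neg` pin down the same product at three points: the raw `k_quants::matmul` output, the exact f32 reference (rounded), and the `QMatMul` path, which should agree with the raw quantized result. The spread between the quantized and f32 numbers (e.g. 243524 vs 244064 in the first entry) is Q4_0 quantization error, which is why the f32 reference is asserted separately rather than compared against the quantized result for equality.
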
diff --git a/candle-examples/examples/ggml/main.rs b/candle-examples/examples/ggml/main.rs
index 912bc53a..7d6ec2ca 100644
--- a/candle-examples/examples/ggml/main.rs
+++ b/candle-examples/examples/ggml/main.rs
@@ -2,6 +2,7 @@
use clap::Parser;
use std::collections::HashMap;
use std::io::Write;
+use tokenizers::Tokenizer;
use candle::quantized::ggml_file::Content;
use candle::quantized::{QMatMul, QTensor};
@@ -52,6 +53,7 @@ struct LayerWeights {
head_dim: usize,
cos: Tensor,
sin: Tensor,
+ kv_cache: Option<(Tensor, Tensor)>,
}
fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
@@ -75,7 +77,7 @@ impl LayerWeights {
Ok(rope)
}
- fn forward_attn(&self, x: &Tensor, mask: &Tensor, index_pos: usize) -> Result<Tensor> {
+ fn forward_attn(&mut self, x: &Tensor, mask: &Tensor, index_pos: usize) -> Result<Tensor> {
let (b_sz, seq_len, n_embd) = x.dims3()?;
let q = self.attention_wq.forward(x)?;
let k = self.attention_wk.forward(x)?;
@@ -94,7 +96,15 @@ impl LayerWeights {
let q = self.apply_rotary_emb(&q, index_pos)?;
let k = self.apply_rotary_emb(&k, index_pos)?;
- // TODO: KV cache.
+ let (k, v) = match &self.kv_cache {
+ None => (k, v),
+ Some((k_cache, v_cache)) => {
+ let k = Tensor::cat(&[k_cache, &k], 2)?.contiguous()?;
+ let v = Tensor::cat(&[v_cache, &v], 2)?.contiguous()?;
+ (k, v)
+ }
+ };
+ self.kv_cache = Some((k.clone(), v.clone()));
// If we start supporting MQA, we need to repeat the k and v tensors here.
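
The cache keeps keys and values in the (batch, heads, seq, head_dim) layout produced above, so past and new tokens concatenate along dim 2 and each decoding step only has to project the freshly sampled token. A minimal standalone sketch of the update, using a hypothetical `update_cache` helper rather than the method above:

    use candle::{Result, Tensor};

    // Hypothetical helper mirroring the cache logic: append the new k/v
    // along the sequence axis (dim 2) and remember the grown tensors.
    fn update_cache(
        cache: &mut Option<(Tensor, Tensor)>,
        k_new: &Tensor,
        v_new: &Tensor,
    ) -> Result<(Tensor, Tensor)> {
        let (k, v) = match cache.as_ref() {
            None => (k_new.clone(), v_new.clone()),
            Some((k_cache, v_cache)) => (
                Tensor::cat(&[k_cache, k_new], 2)?.contiguous()?,
                Tensor::cat(&[v_cache, v_new], 2)?.contiguous()?,
            ),
        };
        *cache = Some((k.clone(), v.clone()));
        Ok((k, v))
    }
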
@@ -181,6 +191,7 @@ impl ModelWeights {
head_dim: (ct.hparams.n_embd / ct.hparams.n_head) as usize,
cos: cos.clone(),
sin: sin.clone(),
+ kv_cache: None,
})
}
Ok(Self {
@@ -209,7 +220,7 @@ impl ModelWeights {
let (_b_sz, seq_len) = x.dims2()?;
let mask = self.mask(seq_len)?;
let mut layer_in = self.tok_embeddings.forward(x)?;
- for (_layer_idx, layer) in self.layers.iter().enumerate() {
+ for layer in self.layers.iter_mut() {
let x = layer_in;
let residual = &x;
let x = layer.attention_norm.forward(&x)?;
@@ -237,7 +248,7 @@ impl ModelWeights {
struct Args {
/// GGML file to load, typically a .bin file generated by the quantize command from llama.cpp
#[arg(long)]
- model: String,
+ model: Option<String>,
/// The initial prompt.
#[arg(long)]
@@ -249,7 +260,7 @@ struct Args {
/// The tokenizer config in json format.
#[arg(long)]
- tokenizer: String,
+ tokenizer: Option<String>,
/// The temperature used to generate samples.
#[arg(long)]
@@ -260,11 +271,36 @@ struct Args {
seed: u64,
}
+impl Args {
+ fn tokenizer(&self) -> anyhow::Result<Tokenizer> {
+ let tokenizer_path = match &self.tokenizer {
+ Some(config) => std::path::PathBuf::from(config),
+ None => {
+ let api = hf_hub::api::sync::Api::new()?;
+ let api = api.model("hf-internal-testing/llama-tokenizer".to_string());
+ api.get("tokenizer.json")?
+ }
+ };
+ Tokenizer::from_file(tokenizer_path).map_err(anyhow::Error::msg)
+ }
+
+ fn model(&self) -> anyhow::Result<std::path::PathBuf> {
+ let model_path = match &self.model {
+ Some(config) => std::path::PathBuf::from(config),
+ None => {
+ let api = hf_hub::api::sync::Api::new()?;
+ let api = api.model("TheBloke/Llama-2-7B-GGML".to_string());
+ api.get("llama-2-7b.ggmlv3.q4_0.bin")?
+ }
+ };
+ Ok(model_path)
+ }
+}
+
fn main() -> anyhow::Result<()> {
- use tokenizers::Tokenizer;
let args = Args::parse();
- let mut file = std::fs::File::open(args.model)?;
+ let mut file = std::fs::File::open(&args.model()?)?;
let start = std::time::Instant::now();
let model = Content::read(&mut file)?;
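
With these defaults the example runs without any flags: the weights come from TheBloke/Llama-2-7B-GGML and the tokenizer from hf-internal-testing/llama-tokenizer, both fetched through hf-hub and cached locally on first use. An assumed invocation (the exact prompt is illustrative):

    cargo run --example ggml --release -- --prompt "Rust is"
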
@@ -293,7 +329,7 @@ fn main() -> anyhow::Result<()> {
let mut model = ModelWeights::new(model)?;
println!("model built");
- let tokenizer = Tokenizer::from_file(args.tokenizer).map_err(anyhow::Error::msg)?;
+ let tokenizer = args.tokenizer()?;
let prompt = args.prompt.as_ref().map_or(DEFAULT_PROMPT, |p| p.as_str());
let mut tokens = tokenizer
.encode(prompt, true)
@@ -302,8 +338,11 @@ fn main() -> anyhow::Result<()> {
.to_vec();
let mut index_pos = 0;
let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature);
- for _index in 0..args.sample_len {
- let context_size = tokens.len();
+ let start_gen = std::time::Instant::now();
+ let mut token_generated = 0;
+ print!("{prompt}");
+ for index in 0..args.sample_len {
+ let context_size = if index == 0 { tokens.len() } else { 1 };
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
let input = Tensor::new(ctxt, &Device::Cpu)?.unsqueeze(0)?;
let logits = model.forward(&input, index_pos)?;
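
With the KV cache in place, only the first iteration feeds the whole prompt; every later step feeds just the newly sampled token, while `index_pos` keeps the absolute position so the rotary embeddings stay aligned. A small sketch of the selection logic in isolation:

    // The context window fed to the model at each step: the full prompt on
    // step 0, a single token afterwards (the cache holds the rest).
    fn context(tokens: &[u32], index: usize) -> &[u32] {
        let context_size = if index == 0 { tokens.len() } else { 1 };
        &tokens[tokens.len().saturating_sub(context_size)..]
    }

    fn main() {
        let tokens = [1u32, 2, 3];
        assert_eq!(context(&tokens, 0), &[1, 2, 3]); // prompt pass
        assert_eq!(context(&tokens, 5), &[3]);       // single-token step
    }
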
@@ -311,6 +350,7 @@ fn main() -> anyhow::Result<()> {
index_pos += ctxt.len();
let next_token = logits_processor.sample(&logits)?;
+ token_generated += 1;
tokens.push(next_token);
// Extracting the last token as a string is complicated, here we just apply some simple
@@ -323,5 +363,11 @@ fn main() -> anyhow::Result<()> {
std::io::stdout().flush()?;
}
}
+ let dt = start_gen.elapsed();
+ println!(
+ "\n\n{} tokens generated ({} token/s)\n",
+ token_generated,
+ token_generated as f64 / dt.as_secs_f64(),
+ );
Ok(())
}