summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-08-16 20:15:11 +0100
committerGitHub <noreply@github.com>2023-08-16 20:15:11 +0100
commit3bedba1fcedb26e3ee850a9c1f20e320d679bf3b (patch)
tree96c55494d34f90a3273daea79503cb1abf0f3394
parentc5f45887dc32bc6575c9d55135def391b949ce98 (diff)
downloadcandle-3bedba1fcedb26e3ee850a9c1f20e320d679bf3b.tar.gz
candle-3bedba1fcedb26e3ee850a9c1f20e320d679bf3b.tar.bz2
candle-3bedba1fcedb26e3ee850a9c1f20e320d679bf3b.zip
Use a zipped iterator. (#475)
* Use a zipped iterator. * Add to/from float for q8k.
-rw-r--r--candle-core/src/quantized/k_quants.rs65
1 files changed, 54 insertions, 11 deletions
diff --git a/candle-core/src/quantized/k_quants.rs b/candle-core/src/quantized/k_quants.rs
index 366eca1e..28ac896e 100644
--- a/candle-core/src/quantized/k_quants.rs
+++ b/candle-core/src/quantized/k_quants.rs
@@ -510,16 +510,59 @@ impl GgmlType for BlockQ8K {
type VecDotType = BlockQ8K;
fn vec_dot(_n: usize, _xs: &[Self], _ys: &[Self::VecDotType]) -> Result<f32> {
- todo!()
+ unreachable!()
}
- fn from_float(_xs: &[f32], _ys: &mut [Self]) -> Result<()> {
- todo!()
+ fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> {
+ let k = xs.len();
+ if k % QK_K != 0 {
+ crate::bail!("quantize_row_q8k: {k} is not divisible by {QK_K}")
+ }
+ for (i, y) in ys.iter_mut().enumerate() {
+ let mut max = 0f32;
+ let mut amax = 0f32;
+ let xs = &xs[i * QK_K..(i + 1) * QK_K];
+ for &x in xs.iter() {
+ if amax < x.abs() {
+ amax = x.abs();
+ max = x;
+ }
+ }
+ if amax == 0f32 {
+ y.d = 0f32;
+ y.qs.fill(0)
+ } else {
+ let iscale = -128f32 / max;
+ for (j, q) in y.qs.iter_mut().enumerate() {
+ // ggml uses nearest_int with bit magic here, maybe we want the same
+ // but we would have to test and benchmark it.
+ let v = (iscale * xs[j]).round();
+ *q = v.min(127.) as i8
+ }
+ for j in 0..QK_K / 16 {
+ let mut sum = 0i32;
+ for ii in 0..16 {
+ sum += y.qs[j * 16 + ii] as i32
+ }
+ y.bsums[j] = sum as i16
+ }
+ y.d = 1.0 / iscale
+ }
+ }
+ Ok(())
}
- // https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L533
- fn to_float(_xs: &[Self], _ys: &mut [f32]) -> Result<()> {
- todo!()
+ fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> {
+ let k = ys.len();
+ if k % QK_K != 0 {
+ crate::bail!("dequantize_row_q8k: {k} is not divisible by {QK_K}")
+ }
+ for (i, x) in xs.iter().enumerate() {
+ for (j, &q) in x.qs.iter().enumerate() {
+ ys[i * QK_K + j] = x.d * q as f32
+ }
+ }
+ Ok(())
}
}
@@ -601,14 +644,14 @@ impl GgmlType for BlockQ4_0 {
// Generic implementation.
let mut sumf = 0f32;
- for i in 0..nb {
+ for (xs, ys) in xs.iter().zip(ys.iter()) {
let mut sum_i = 0;
for j in 0..qk / 2 {
- let v0 = (xs[i].qs[j] & 0x0F) as i32 - 8;
- let v1 = (xs[i].qs[j] >> 4) as i32 - 8;
- sum_i += v0 * ys[i].qs[j] as i32 + v1 * ys[i].qs[j + qk / 2] as i32
+ let v0 = (xs.qs[j] & 0x0F) as i32 - 8;
+ let v1 = (xs.qs[j] >> 4) as i32 - 8;
+ sum_i += v0 * ys.qs[j] as i32 + v1 * ys.qs[j + qk / 2] as i32
}
- sumf += sum_i as f32 * f16::to_f32(xs[i].d) * f16::to_f32(ys[i].d)
+ sumf += sum_i as f32 * f16::to_f32(xs.d) * f16::to_f32(ys.d)
}
Ok(sumf)
}