summaryrefslogtreecommitdiff
path: root/candle-core
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2024-01-11 23:15:11 +0100
committerGitHub <noreply@github.com>2024-01-11 23:15:11 +0100
commit41915184bb3e530cc8184fdd8841c66df9285684 (patch)
tree57333a77415fa84c0cb62fa755ca7e04f46bcdd0 /candle-core
parentc1876b80415f5b84f3ea07589f359c786035fc5f (diff)
downloadcandle-41915184bb3e530cc8184fdd8841c66df9285684.tar.gz
candle-41915184bb3e530cc8184fdd8841c66df9285684.tar.bz2
candle-41915184bb3e530cc8184fdd8841c66df9285684.zip
Bugfix for dequantizing q5k layers. (#1569)
Diffstat (limited to 'candle-core')
-rw-r--r--candle-core/src/quantized/k_quants.rs8
-rw-r--r--candle-core/tests/quantized_tests.rs2
2 files changed, 5 insertions, 5 deletions
diff --git a/candle-core/src/quantized/k_quants.rs b/candle-core/src/quantized/k_quants.rs
index d16289e6..6210ac1e 100644
--- a/candle-core/src/quantized/k_quants.rs
+++ b/candle-core/src/quantized/k_quants.rs
@@ -1545,13 +1545,13 @@ impl GgmlType for BlockQ5K {
let d2 = d * sc as f32;
let m2 = min * m as f32;
for (ql, qh) in ql.iter().zip(qh) {
- let to_add = if qh & u1 != 0 { 16 } else { 1 };
- y[ys_index] = d1 * ((ql & 0xF) + to_add) as f32 - m1;
+ let to_add = if qh & u1 != 0 { 16f32 } else { 0f32 };
+ y[ys_index] = d1 * ((ql & 0xF) as f32 + to_add) - m1;
ys_index += 1;
}
for (ql, qh) in ql.iter().zip(qh) {
- let to_add = if qh & u2 != 0 { 16 } else { 1 };
- y[ys_index] = d2 * ((ql >> 4) + to_add) as f32 - m2;
+ let to_add = if qh & u2 != 0 { 16f32 } else { 0f32 };
+ y[ys_index] = d2 * ((ql >> 4) as f32 + to_add) - m2;
ys_index += 1;
}
is += 2;
diff --git a/candle-core/tests/quantized_tests.rs b/candle-core/tests/quantized_tests.rs
index e7a2ea7f..d31e77a7 100644
--- a/candle-core/tests/quantized_tests.rs
+++ b/candle-core/tests/quantized_tests.rs
@@ -407,7 +407,7 @@ fn quantize_q5k() -> Result<()> {
let dst = round_vector(&dst);
assert_eq!(
[dst[0], dst[128], dst[256], dst[512], dst[800], dst[1023]],
- [-0.499, -0.372, -0.249, 0.001, 0.279, 0.499]
+ [-0.5, -0.373, -0.25, 0.0, 0.279, 0.499]
);
let (src_big, mut dst_big) = get_test_vector(128.0, 1024);