diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-01-11 23:15:11 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-01-11 23:15:11 +0100 |
commit | 41915184bb3e530cc8184fdd8841c66df9285684 (patch) | |
tree | 57333a77415fa84c0cb62fa755ca7e04f46bcdd0 /candle-core | |
parent | c1876b80415f5b84f3ea07589f359c786035fc5f (diff) | |
download | candle-41915184bb3e530cc8184fdd8841c66df9285684.tar.gz candle-41915184bb3e530cc8184fdd8841c66df9285684.tar.bz2 candle-41915184bb3e530cc8184fdd8841c66df9285684.zip |
Bugfix for dequantizing q5k layers. (#1569)
Diffstat (limited to 'candle-core')
-rw-r--r-- | candle-core/src/quantized/k_quants.rs | 8 | ||||
-rw-r--r-- | candle-core/tests/quantized_tests.rs | 2 |
2 files changed, 5 insertions, 5 deletions
diff --git a/candle-core/src/quantized/k_quants.rs b/candle-core/src/quantized/k_quants.rs index d16289e6..6210ac1e 100644 --- a/candle-core/src/quantized/k_quants.rs +++ b/candle-core/src/quantized/k_quants.rs @@ -1545,13 +1545,13 @@ impl GgmlType for BlockQ5K { let d2 = d * sc as f32; let m2 = min * m as f32; for (ql, qh) in ql.iter().zip(qh) { - let to_add = if qh & u1 != 0 { 16 } else { 1 }; - y[ys_index] = d1 * ((ql & 0xF) + to_add) as f32 - m1; + let to_add = if qh & u1 != 0 { 16f32 } else { 0f32 }; + y[ys_index] = d1 * ((ql & 0xF) as f32 + to_add) - m1; ys_index += 1; } for (ql, qh) in ql.iter().zip(qh) { - let to_add = if qh & u2 != 0 { 16 } else { 1 }; - y[ys_index] = d2 * ((ql >> 4) + to_add) as f32 - m2; + let to_add = if qh & u2 != 0 { 16f32 } else { 0f32 }; + y[ys_index] = d2 * ((ql >> 4) as f32 + to_add) - m2; ys_index += 1; } is += 2; diff --git a/candle-core/tests/quantized_tests.rs b/candle-core/tests/quantized_tests.rs index e7a2ea7f..d31e77a7 100644 --- a/candle-core/tests/quantized_tests.rs +++ b/candle-core/tests/quantized_tests.rs @@ -407,7 +407,7 @@ fn quantize_q5k() -> Result<()> { let dst = round_vector(&dst); assert_eq!( [dst[0], dst[128], dst[256], dst[512], dst[800], dst[1023]], - [-0.499, -0.372, -0.249, 0.001, 0.279, 0.499] + [-0.5, -0.373, -0.25, 0.0, 0.279, 0.499] ); let (src_big, mut dst_big) = get_test_vector(128.0, 1024); |