diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-08-22 15:57:46 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-08-22 15:57:46 +0100 |
commit | ec665acad7c882b2856635d0b7e99d4d3eee0a9e (patch) | |
tree | 4d0b85e65c07033cb370b9812d7adf217a00daee /candle-core/src/quantized/k_quants.rs | |
parent | cf27b9b6368d0af086e107d1ce890b2993825282 (diff) | |
download | candle-ec665acad7c882b2856635d0b7e99d4d3eee0a9e.tar.gz candle-ec665acad7c882b2856635d0b7e99d4d3eee0a9e.tar.bz2 candle-ec665acad7c882b2856635d0b7e99d4d3eee0a9e.zip |
Revert "Avoid some mut in quantized functions. (#550)" (#552)
This reverts commit cf27b9b6368d0af086e107d1ce890b2993825282.
Diffstat (limited to 'candle-core/src/quantized/k_quants.rs')
-rw-r--r-- | candle-core/src/quantized/k_quants.rs | 41 |
1 file changed, 25 insertions, 16 deletions
```diff
diff --git a/candle-core/src/quantized/k_quants.rs b/candle-core/src/quantized/k_quants.rs
index 3e45bc6d..bfc471a3 100644
--- a/candle-core/src/quantized/k_quants.rs
+++ b/candle-core/src/quantized/k_quants.rs
@@ -503,7 +503,8 @@ impl GgmlType for BlockQ2K {
                 }
                 let dm = block.dmin.to_f32() * (block.scales[j] >> 4) as f32;
                 for ii in 0..16 {
-                    let ll = nearest_int((x[16 * j + ii] + dm) / d).clamp(0, 3);
+                    let mut ll = nearest_int((x[16 * j + ii] + dm) / d);
+                    ll = ll.clamp(0, 3);
                     big_l[16 * j + ii] = ll as u8;
                 }
             }
@@ -586,14 +587,14 @@ impl GgmlType for BlockQ3K {
             if max_scale != 0.0 {
                 let iscale = -32.0 / max_scale;
                 for (j, scale) in scales.iter().enumerate() {
-                    let l_val = nearest_int(iscale * scale);
-                    let l_val = l_val.clamp(-32, 31) + 32;
+                    let mut l_val = nearest_int(iscale * scale);
+                    l_val = l_val.clamp(-32, 31) + 32;
                     if j < 8 {
                         block.scales[j] = (l_val & 0xF) as u8;
                     } else {
                         block.scales[j - 8] |= ((l_val & 0xF) << 4) as u8;
                     }
-                    let l_val = l_val >> 4;
+                    l_val >>= 4;
                     block.scales[j % 4 + 8] |= (l_val << (2 * (j / 4))) as u8;
                 }
                 block.d = f16::from_f32(1.0 / iscale);
@@ -613,8 +614,9 @@ impl GgmlType for BlockQ3K {
                 let d = block.d.to_f32() * sc as f32;
                 if d != 0.0 {
                     for ii in 0..16 {
-                        let l_val = nearest_int(x[16 * j + ii] / d);
-                        l[16 * j + ii] = (l_val.clamp(-4, 3) + 4) as i8;
+                        let mut l_val = nearest_int(x[16 * j + ii] / d);
+                        l_val = l_val.clamp(-4, 3);
+                        l[16 * j + ii] = (l_val + 4) as i8;
                     }
                 }
             }
@@ -700,7 +702,7 @@ impl GgmlType for BlockQ3K {
                         // 16 block finished => advance scale index
                         is += 1;
                     }
-                    // 32 block finished => increase shift and m
+                    //32 block finished => increase shift and m
                     shift += 2;
                     m <<= 1;
                 }
@@ -741,8 +743,10 @@ impl GgmlType for BlockQ4K {
             let inv_min = if max_min > 0.0 { 63.0 / max_min } else { 0.0 };
 
             for j in 0..QK_K / 32 {
-                let ls = nearest_int(inv_scale * scales[j]).min(63) as u8;
-                let lm = nearest_int(inv_min * mins[j]).min(63) as u8;
+                let mut ls = nearest_int(inv_scale * scales[j]) as u8;
+                let mut lm = nearest_int(inv_min * mins[j]) as u8;
+                ls = std::cmp::min(63, ls);
+                lm = std::cmp::min(63, lm);
                 if j < 4 {
                     block.scales[j] = ls;
                     block.scales[j + 4] = lm;
@@ -764,8 +768,9 @@ impl GgmlType for BlockQ4K {
                 if d != 0.0 {
                     let dm = block.dmin.to_f32() * m as f32;
                     for ii in 0..32 {
-                        let l_val = nearest_int((x[32 * j + ii] + dm) / d);
-                        l[32 * j + ii] = l_val.clamp(0, 15) as u8;
+                        let mut l_val = nearest_int((x[32 * j + ii] + dm) / d);
+                        l_val = l_val.clamp(0, 15);
+                        l[32 * j + ii] = l_val as u8;
                     }
                 }
             }
@@ -786,10 +791,10 @@ impl GgmlType for BlockQ4K {
             let d = block.d.to_f32();
             let min = block.dmin.to_f32();
             let q = &block.qs;
+            let mut is = 0;
             let mut ys_index = 0;
 
             for j in (0..QK_K).step_by(64) {
-                let is = j * 2;
                 let q = &q[j / 2..j / 2 + 32];
                 let (sc, m) = get_scale_min_k4(is, &block.scales);
                 let d1 = d * sc as f32;
@@ -805,6 +810,7 @@ impl GgmlType for BlockQ4K {
                     y[ys_index] = d2 * (q >> 4) as f32 - m2;
                     ys_index += 1;
                 }
+                is += 2;
             }
         }
         Ok(())
@@ -842,8 +848,10 @@ impl GgmlType for BlockQ5K {
             };
             let inv_min = if max_min > 0.0 { 63.0 / max_min } else { 0.0 };
             for j in 0..QK_K / 32 {
-                let ls = nearest_int(inv_scale * scales[j]).min(63) as u8;
-                let lm = nearest_int(inv_min * mins[j]).min(63) as u8;
+                let mut ls = nearest_int(inv_scale * scales[j]) as u8;
+                let mut lm = nearest_int(inv_min * mins[j]) as u8;
+                ls = ls.min(63);
+                lm = lm.min(63);
                 if j < 4 {
                     block.scales[j] = ls;
                     block.scales[j + 4] = lm;
@@ -865,8 +873,9 @@ impl GgmlType for BlockQ5K {
                 }
                 let dm = block.dmin.to_f32() * m as f32;
                 for ii in 0..32 {
-                    let ll = nearest_int((x[32 * j + ii] + dm) / d);
-                    l[32 * j + ii] = ll.clamp(0, 31) as u8;
+                    let mut ll = nearest_int((x[32 * j + ii] + dm) / d);
+                    ll = ll.min(31).max(0);
+                    l[32 * j + ii] = ll as u8;
                 }
             }
 
```