summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-08-22 18:51:20 +0100
committerGitHub <noreply@github.com>2023-08-22 18:51:20 +0100
commit07067b01dce3c63b45fe4bdeb8d972f279e88b45 (patch)
tree137fd060730d328d7fdde183e54821a859d7b647
parentcc22d4db20d5623ae7fde294eefc6ff3df1b31e8 (diff)
downloadcandle-07067b01dce3c63b45fe4bdeb8d972f279e88b45.tar.gz
candle-07067b01dce3c63b45fe4bdeb8d972f279e88b45.tar.bz2
candle-07067b01dce3c63b45fe4bdeb8d972f279e88b45.zip
Avoid some mutable variables (take 2). (#554)
* Avoid some mutable variables (take 2). * Fix.
-rw-r--r--candle-core/src/quantized/k_quants.rs38
-rw-r--r--candle-core/src/quantized/utils.rs28
2 files changed, 29 insertions, 37 deletions
diff --git a/candle-core/src/quantized/k_quants.rs b/candle-core/src/quantized/k_quants.rs
index bfc471a3..e2f4ab74 100644
--- a/candle-core/src/quantized/k_quants.rs
+++ b/candle-core/src/quantized/k_quants.rs
@@ -503,8 +503,7 @@ impl GgmlType for BlockQ2K {
}
let dm = block.dmin.to_f32() * (block.scales[j] >> 4) as f32;
for ii in 0..16 {
- let mut ll = nearest_int((x[16 * j + ii] + dm) / d);
- ll = ll.clamp(0, 3);
+ let ll = nearest_int((x[16 * j + ii] + dm) / d).clamp(0, 3);
big_l[16 * j + ii] = ll as u8;
}
}
@@ -587,14 +586,14 @@ impl GgmlType for BlockQ3K {
if max_scale != 0.0 {
let iscale = -32.0 / max_scale;
for (j, scale) in scales.iter().enumerate() {
- let mut l_val = nearest_int(iscale * scale);
- l_val = l_val.clamp(-32, 31) + 32;
+ let l_val = nearest_int(iscale * scale);
+ let l_val = l_val.clamp(-32, 31) + 32;
if j < 8 {
block.scales[j] = (l_val & 0xF) as u8;
} else {
block.scales[j - 8] |= ((l_val & 0xF) << 4) as u8;
}
- l_val >>= 4;
+ let l_val = l_val >> 4;
block.scales[j % 4 + 8] |= (l_val << (2 * (j / 4))) as u8;
}
block.d = f16::from_f32(1.0 / iscale);
@@ -614,9 +613,8 @@ impl GgmlType for BlockQ3K {
let d = block.d.to_f32() * sc as f32;
if d != 0.0 {
for ii in 0..16 {
- let mut l_val = nearest_int(x[16 * j + ii] / d);
- l_val = l_val.clamp(-4, 3);
- l[16 * j + ii] = (l_val + 4) as i8;
+ let l_val = nearest_int(x[16 * j + ii] / d);
+ l[16 * j + ii] = (l_val.clamp(-4, 3) + 4) as i8;
}
}
}
@@ -702,7 +700,7 @@ impl GgmlType for BlockQ3K {
// 16 block finished => advance scale index
is += 1;
}
- //32 block finished => increase shift and m
+ // 32 block finished => increase shift and m
shift += 2;
m <<= 1;
}
@@ -743,10 +741,8 @@ impl GgmlType for BlockQ4K {
let inv_min = if max_min > 0.0 { 63.0 / max_min } else { 0.0 };
for j in 0..QK_K / 32 {
- let mut ls = nearest_int(inv_scale * scales[j]) as u8;
- let mut lm = nearest_int(inv_min * mins[j]) as u8;
- ls = std::cmp::min(63, ls);
- lm = std::cmp::min(63, lm);
+ let ls = nearest_int(inv_scale * scales[j]).min(63) as u8;
+ let lm = nearest_int(inv_min * mins[j]).min(63) as u8;
if j < 4 {
block.scales[j] = ls;
block.scales[j + 4] = lm;
@@ -768,9 +764,8 @@ impl GgmlType for BlockQ4K {
if d != 0.0 {
let dm = block.dmin.to_f32() * m as f32;
for ii in 0..32 {
- let mut l_val = nearest_int((x[32 * j + ii] + dm) / d);
- l_val = l_val.clamp(0, 15);
- l[32 * j + ii] = l_val as u8;
+ let l_val = nearest_int((x[32 * j + ii] + dm) / d);
+ l[32 * j + ii] = l_val.clamp(0, 15) as u8;
}
}
}
@@ -848,10 +843,8 @@ impl GgmlType for BlockQ5K {
};
let inv_min = if max_min > 0.0 { 63.0 / max_min } else { 0.0 };
for j in 0..QK_K / 32 {
- let mut ls = nearest_int(inv_scale * scales[j]) as u8;
- let mut lm = nearest_int(inv_min * mins[j]) as u8;
- ls = ls.min(63);
- lm = lm.min(63);
+ let ls = nearest_int(inv_scale * scales[j]).min(63) as u8;
+ let lm = nearest_int(inv_min * mins[j]).min(63) as u8;
if j < 4 {
block.scales[j] = ls;
block.scales[j + 4] = lm;
@@ -873,9 +866,8 @@ impl GgmlType for BlockQ5K {
}
let dm = block.dmin.to_f32() * m as f32;
for ii in 0..32 {
- let mut ll = nearest_int((x[32 * j + ii] + dm) / d);
- ll = ll.min(31).max(0);
- l[32 * j + ii] = ll as u8;
+ let ll = nearest_int((x[32 * j + ii] + dm) / d);
+ l[32 * j + ii] = ll.clamp(0, 31) as u8;
}
}
diff --git a/candle-core/src/quantized/utils.rs b/candle-core/src/quantized/utils.rs
index fded9d61..edbffa35 100644
--- a/candle-core/src/quantized/utils.rs
+++ b/candle-core/src/quantized/utils.rs
@@ -4,7 +4,9 @@ pub(super) fn nearest_int(v: f32) -> i32 {
v.round() as i32
}
-/// Validates that the input and output are the right size and returns an iterator which maps each input region `xs` to its corresponding output block in `ys`. Each output region is guaranteed to be `T::BLCK_SIZE` long.
+/// Validates that the input and output are the right size and returns an iterator which maps each
+/// input region `xs` to its corresponding output block in `ys`. Each output region is guaranteed
+/// to be `T::BLCK_SIZE` long.
pub(super) fn group_for_quantization<'a, 'b, T: super::k_quants::GgmlType>(
xs: &'b [f32],
ys: &'a mut [T],
@@ -23,7 +25,9 @@ pub(super) fn group_for_quantization<'a, 'b, T: super::k_quants::GgmlType>(
Ok(ys.iter_mut().zip(xs.chunks_exact(block_size)).collect())
}
-/// Validates that the input and output are the right size and returns an iterator which maps each input block `xs` to its corresponding output region in `ys`. Each output region is guaranteed to be `T::BLCK_SIZE` long.
+/// Validates that the input and output are the right size and returns an iterator which maps each
+/// input block `xs` to its corresponding output region in `ys`. Each output region is guaranteed
+/// to be `T::BLCK_SIZE` long.
pub(super) fn group_for_dequantization<'a, 'b, T: super::k_quants::GgmlType>(
xs: &'a [T],
ys: &'b mut [f32],
@@ -174,7 +178,7 @@ pub(super) unsafe fn make_qx_quants(
for i in 0..n {
let x = *x.add(i);
let l = nearest_int(iscale * x);
- let l = i32::max(-nmax, i32::min(nmax - 1, l));
+ let l = l.clamp(-nmax, nmax - 1);
let w = if weight_type == 1 { x * x } else { 1. };
let l = l as f32;
sumlx += w * x * l;
@@ -198,7 +202,7 @@ pub(super) fn make_qkx1_quants(nmax: i32, ntry: usize, x: &[f32]) -> (f32, f32)
let n = x.len();
let mut l = vec![0; n];
// Get min/max
- let mut min = *x
+ let min = *x
.iter()
.take(n)
.min_by(|a, b| a.total_cmp(b))
@@ -211,9 +215,7 @@ pub(super) fn make_qkx1_quants(nmax: i32, ntry: usize, x: &[f32]) -> (f32, f32)
}
// Ensure min <= 0.0
- if min > 0.0 {
- min = 0.0;
- }
+ let mut min = min.min(0.);
// Compute scale and inverse scale
let mut iscale = nmax as f32 / (max - min);
@@ -225,8 +227,7 @@ pub(super) fn make_qkx1_quants(nmax: i32, ntry: usize, x: &[f32]) -> (f32, f32)
let mut did_change = false;
for (i, value) in x.iter().enumerate().take(n) {
- let mut li = nearest_int(iscale * (value - min));
- li = li.clamp(0, nmax);
+ let li = nearest_int(iscale * (value - min)).clamp(0, nmax);
let clamped_li = li as u8;
if clamped_li != l[i] {
l[i] = clamped_li;
@@ -280,8 +281,8 @@ pub(super) fn make_q3_quants(x: &[f32], nmax: i32, do_rmse: bool) -> f32 {
let mut sumlx = 0.0;
let mut suml2 = 0.0;
for i in 0..n {
- let mut li = (iscale * x[i]).round() as i32;
- li = li.clamp(-nmax, nmax - 1);
+ let li = (iscale * x[i]).round() as i32;
+ let li = li.clamp(-nmax, nmax - 1);
l[i] = li as i8;
let w = x[i] * x[i];
sumlx += w * x[i] * li as f32;
@@ -318,9 +319,8 @@ pub(super) fn make_q3_quants(x: &[f32], nmax: i32, do_rmse: bool) -> f32 {
return sumlx / suml2;
}
for i in 0..n {
- let mut li = (iscale * x[i]).round() as i32;
- li = li.clamp(-nmax, nmax - 1);
- l[i] = (li + nmax) as i8;
+ let li = (iscale * x[i]).round() as i32;
+ l[i] = (li.clamp(-nmax, nmax - 1) + nmax) as i8;
}
1.0 / iscale
}