summary | refs | log | tree | commit | diff
path: root/candle-core/examples/tensor-tools.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-core/examples/tensor-tools.rs')
-rw-r--r-- candle-core/examples/tensor-tools.rs 19
1 files changed, 17 insertions, 2 deletions
diff --git a/candle-core/examples/tensor-tools.rs b/candle-core/examples/tensor-tools.rs
index c3459004..c0d5a334 100644
--- a/candle-core/examples/tensor-tools.rs
+++ b/candle-core/examples/tensor-tools.rs
@@ -243,12 +243,27 @@ fn run_quantize_safetensors(
Quantization::F16 => QTensor::quantize::<half::f16>,
Quantization::F32 => QTensor::quantize::<f32>,
};
+ let block_size = match q {
+ Quantization::Q4_0 => k_quants::QK4_0,
+ Quantization::Q4_1 => k_quants::QK4_1,
+ Quantization::Q5_0 => k_quants::QK5_0,
+ Quantization::Q5_1 => k_quants::QK5_1,
+ Quantization::Q8_0 => k_quants::QK8_0,
+ Quantization::Q8_1 => k_quants::QK8_1,
+ Quantization::Q2k
+ | Quantization::Q3k
+ | Quantization::Q4k
+ | Quantization::Q5k
+ | Quantization::Q6k
+ | Quantization::Q8k => k_quants::QK_K,
+ Quantization::F16 | Quantization::F32 => 1,
+ };
let qtensors = tensors
.into_par_iter()
.map(|(name, tensor)| {
- println!(" quantizing {name} {tensor:?}");
- let should_quantize = tensor.rank() == 2 && tensor.dim(0)? % 256 == 0;
+ let should_quantize = tensor.rank() == 2 && tensor.dim(1)? % block_size == 0;
+ println!(" quantizing {name} {tensor:?} {should_quantize}");
let tensor = if should_quantize {
quantize_fn(&tensor)?
} else {