Diffstat (limited to 'candle-core/examples/tensor-tools.rs')
-rw-r--r--  candle-core/examples/tensor-tools.rs  53
1 file changed, 53 insertions(+), 0 deletions(-)
diff --git a/candle-core/examples/tensor-tools.rs b/candle-core/examples/tensor-tools.rs
index 2bc1fa2e..c3459004 100644
--- a/candle-core/examples/tensor-tools.rs
+++ b/candle-core/examples/tensor-tools.rs
@@ -218,12 +218,65 @@ fn run_ls(file: &std::path::PathBuf, format: Option<Format>, verbose: bool) -> R
     Ok(())
 }
+fn run_quantize_safetensors(
+    in_file: std::path::PathBuf,
+    out_file: std::path::PathBuf,
+    q: Quantization,
+) -> Result<()> {
+    let mut out_file = std::fs::File::create(out_file)?;
+    let tensors = candle_core::safetensors::load(in_file, &Device::Cpu)?;
+    println!("tensors: {}", tensors.len());
+
+    let quantize_fn = match q {
+        Quantization::Q4_0 => QTensor::quantize::<k_quants::BlockQ4_0>,
+        Quantization::Q4_1 => QTensor::quantize::<k_quants::BlockQ4_1>,
+        Quantization::Q5_0 => QTensor::quantize::<k_quants::BlockQ5_0>,
+        Quantization::Q5_1 => QTensor::quantize::<k_quants::BlockQ5_1>,
+        Quantization::Q8_0 => QTensor::quantize::<k_quants::BlockQ8_0>,
+        Quantization::Q8_1 => QTensor::quantize::<k_quants::BlockQ8_1>,
+        Quantization::Q2k => QTensor::quantize::<k_quants::BlockQ2K>,
+        Quantization::Q3k => QTensor::quantize::<k_quants::BlockQ3K>,
+        Quantization::Q4k => QTensor::quantize::<k_quants::BlockQ4K>,
+        Quantization::Q5k => QTensor::quantize::<k_quants::BlockQ5K>,
+        Quantization::Q6k => QTensor::quantize::<k_quants::BlockQ6K>,
+        Quantization::Q8k => QTensor::quantize::<k_quants::BlockQ8K>,
+        Quantization::F16 => QTensor::quantize::<half::f16>,
+        Quantization::F32 => QTensor::quantize::<f32>,
+    };
+
+    let qtensors = tensors
+        .into_par_iter()
+        .map(|(name, tensor)| {
+            println!("  quantizing {name} {tensor:?}");
+            let should_quantize = tensor.rank() == 2 && tensor.dim(0)? % 256 == 0;
+            let tensor = if should_quantize {
+                quantize_fn(&tensor)?
+            } else {
+                QTensor::quantize::<f32>(&tensor)?
+            };
+            Ok((name, tensor))
+        })
+        .collect::<Result<Vec<_>>>()?;
+    let qtensors = qtensors
+        .iter()
+        .map(|(k, v)| (k.as_str(), v))
+        .collect::<Vec<_>>();
+    gguf_file::write(&mut out_file, &[], &qtensors)?;
+    Ok(())
+}
+
 fn run_quantize(
     in_file: std::path::PathBuf,
     out_file: std::path::PathBuf,
     q: Quantization,
     qmode: QuantizationMode,
 ) -> Result<()> {
+    if let Some(extension) = in_file.extension() {
+        if extension == "safetensors" {
+            return run_quantize_safetensors(in_file, out_file, q);
+        }
+    }
+
     // Open the out file early so as to fail directly on missing directories etc.
     let mut out_file = std::fs::File::create(out_file)?;
     let mut in_ = std::fs::File::open(&in_file)?;