diff options
Diffstat (limited to 'candle-core/examples/tensor-tools.rs')
-rw-r--r-- | candle-core/examples/tensor-tools.rs | 53 |
1 files changed, 53 insertions, 0 deletions
diff --git a/candle-core/examples/tensor-tools.rs b/candle-core/examples/tensor-tools.rs index 2bc1fa2e..c3459004 100644 --- a/candle-core/examples/tensor-tools.rs +++ b/candle-core/examples/tensor-tools.rs @@ -218,12 +218,65 @@ fn run_ls(file: &std::path::PathBuf, format: Option<Format>, verbose: bool) -> R Ok(()) } +fn run_quantize_safetensors( + in_file: std::path::PathBuf, + out_file: std::path::PathBuf, + q: Quantization, +) -> Result<()> { + let mut out_file = std::fs::File::create(out_file)?; + let tensors = candle_core::safetensors::load(in_file, &Device::Cpu)?; + println!("tensors: {}", tensors.len()); + + let quantize_fn = match q { + Quantization::Q4_0 => QTensor::quantize::<k_quants::BlockQ4_0>, + Quantization::Q4_1 => QTensor::quantize::<k_quants::BlockQ4_1>, + Quantization::Q5_0 => QTensor::quantize::<k_quants::BlockQ5_0>, + Quantization::Q5_1 => QTensor::quantize::<k_quants::BlockQ5_1>, + Quantization::Q8_0 => QTensor::quantize::<k_quants::BlockQ8_0>, + Quantization::Q8_1 => QTensor::quantize::<k_quants::BlockQ8_1>, + Quantization::Q2k => QTensor::quantize::<k_quants::BlockQ2K>, + Quantization::Q3k => QTensor::quantize::<k_quants::BlockQ3K>, + Quantization::Q4k => QTensor::quantize::<k_quants::BlockQ4K>, + Quantization::Q5k => QTensor::quantize::<k_quants::BlockQ5K>, + Quantization::Q6k => QTensor::quantize::<k_quants::BlockQ6K>, + Quantization::Q8k => QTensor::quantize::<k_quants::BlockQ8K>, + Quantization::F16 => QTensor::quantize::<half::f16>, + Quantization::F32 => QTensor::quantize::<f32>, + }; + + let qtensors = tensors + .into_par_iter() + .map(|(name, tensor)| { + println!(" quantizing {name} {tensor:?}"); + let should_quantize = tensor.rank() == 2 && tensor.dim(0)? % 256 == 0; + let tensor = if should_quantize { + quantize_fn(&tensor)? + } else { + QTensor::quantize::<f32>(&tensor)? + }; + Ok((name, tensor)) + }) + .collect::<Result<Vec<_>>>()?; + let qtensors = qtensors + .iter() + .map(|(k, v)| (k.as_str(), v)) + .collect::<Vec<_>>(); + gguf_file::write(&mut out_file, &[], &qtensors)?; + Ok(()) +} + fn run_quantize( in_file: std::path::PathBuf, out_file: std::path::PathBuf, q: Quantization, qmode: QuantizationMode, ) -> Result<()> { + if let Some(extension) = in_file.extension() { + if extension == "safetensors" { + return run_quantize_safetensors(in_file, out_file, q); + } + } + // Open the out file early so as to fail directly on missing directories etc. let mut out_file = std::fs::File::create(out_file)?; let mut in_ = std::fs::File::open(&in_file)?; |