diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-01-11 11:21:01 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-01-11 11:21:01 +0100 |
commit | 0fc95c9f0c426db0f32f7e853035fd3e8415c311 (patch) | |
tree | 407d2f266adca6d612ff3f6603764b5b7f352898 /candle-core | |
parent | 2480c5dbddec7cd086746df595be85fdf1407146 (diff) | |
download | candle-0fc95c9f0c426db0f32f7e853035fd3e8415c311.tar.gz candle-0fc95c9f0c426db0f32f7e853035fd3e8415c311.tar.bz2 candle-0fc95c9f0c426db0f32f7e853035fd3e8415c311.zip |
Add a dequantize command to tensor-tools. (#1565)
* Add a dequantize command to tensor-tools.
* Clippy fixes.
Diffstat (limited to 'candle-core')
-rw-r--r-- | candle-core/examples/tensor-tools.rs | 25 |
1 files changed, 24 insertions, 1 deletions
```diff
diff --git a/candle-core/examples/tensor-tools.rs b/candle-core/examples/tensor-tools.rs
index d06b30d1..337021aa 100644
--- a/candle-core/examples/tensor-tools.rs
+++ b/candle-core/examples/tensor-tools.rs
@@ -102,7 +102,7 @@ enum Command {
     },
 
     Quantize {
-        /// The input file, in gguf format.
+        /// The input file(s), in safetensors format.
         in_file: Vec<std::path::PathBuf>,
 
         /// The output file, in gguf format.
@@ -117,6 +117,15 @@ enum Command {
         #[arg(long, value_enum, default_value_t = QuantizationMode::Llama)]
         mode: QuantizationMode,
     },
+
+    Dequantize {
+        /// The input file, in gguf format.
+        in_file: std::path::PathBuf,
+
+        /// The output file, in safetensors format.
+        #[arg(long)]
+        out_file: std::path::PathBuf,
+    },
 }
 
 #[derive(Parser, Debug, Clone)]
@@ -285,6 +294,19 @@ fn run_quantize_safetensors(
     Ok(())
 }
 
+fn run_dequantize(in_file: std::path::PathBuf, out_file: std::path::PathBuf) -> Result<()> {
+    let mut in_file = std::fs::File::open(in_file)?;
+    let content = gguf_file::Content::read(&mut in_file)?;
+    let mut tensors = std::collections::HashMap::new();
+    for (tensor_name, _) in content.tensor_infos.iter() {
+        let tensor = content.tensor(&mut in_file, tensor_name)?;
+        let tensor = tensor.dequantize(&Device::Cpu)?;
+        tensors.insert(tensor_name.to_string(), tensor);
+    }
+    candle_core::safetensors::save(&tensors, out_file)?;
+    Ok(())
+}
+
 fn run_quantize(
     in_files: &[std::path::PathBuf],
     out_file: std::path::PathBuf,
@@ -379,6 +401,7 @@ fn main() -> anyhow::Result<()> {
             quantization,
             mode,
         } => run_quantize(&in_file, out_file, quantization, mode)?,
+        Command::Dequantize { in_file, out_file } => run_dequantize(in_file, out_file)?,
     }
     Ok(())
 }
```