summaryrefslogtreecommitdiff
path: root/candle-core
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2024-01-11 11:21:01 +0100
committerGitHub <noreply@github.com>2024-01-11 11:21:01 +0100
commit0fc95c9f0c426db0f32f7e853035fd3e8415c311 (patch)
tree407d2f266adca6d612ff3f6603764b5b7f352898 /candle-core
parent2480c5dbddec7cd086746df595be85fdf1407146 (diff)
downloadcandle-0fc95c9f0c426db0f32f7e853035fd3e8415c311.tar.gz
candle-0fc95c9f0c426db0f32f7e853035fd3e8415c311.tar.bz2
candle-0fc95c9f0c426db0f32f7e853035fd3e8415c311.zip
Add a dequantize command to tensor-tools. (#1565)
* Add a dequantize command to tensor-tools. * Clippy fixes.
Diffstat (limited to 'candle-core')
-rw-r--r--candle-core/examples/tensor-tools.rs25
1 file changed, 24 insertions(+), 1 deletion(-)
diff --git a/candle-core/examples/tensor-tools.rs b/candle-core/examples/tensor-tools.rs
index d06b30d1..337021aa 100644
--- a/candle-core/examples/tensor-tools.rs
+++ b/candle-core/examples/tensor-tools.rs
@@ -102,7 +102,7 @@ enum Command {
},
Quantize {
- /// The input file, in gguf format.
+ /// The input file(s), in safetensors format.
in_file: Vec<std::path::PathBuf>,
/// The output file, in gguf format.
@@ -117,6 +117,15 @@ enum Command {
#[arg(long, value_enum, default_value_t = QuantizationMode::Llama)]
mode: QuantizationMode,
},
+
+ Dequantize {
+ /// The input file, in gguf format.
+ in_file: std::path::PathBuf,
+
+ /// The output file, in safetensors format.
+ #[arg(long)]
+ out_file: std::path::PathBuf,
+ },
}
#[derive(Parser, Debug, Clone)]
@@ -285,6 +294,19 @@ fn run_quantize_safetensors(
Ok(())
}
+fn run_dequantize(in_file: std::path::PathBuf, out_file: std::path::PathBuf) -> Result<()> {
+ let mut in_file = std::fs::File::open(in_file)?;
+ let content = gguf_file::Content::read(&mut in_file)?;
+ let mut tensors = std::collections::HashMap::new();
+ for (tensor_name, _) in content.tensor_infos.iter() {
+ let tensor = content.tensor(&mut in_file, tensor_name)?;
+ let tensor = tensor.dequantize(&Device::Cpu)?;
+ tensors.insert(tensor_name.to_string(), tensor);
+ }
+ candle_core::safetensors::save(&tensors, out_file)?;
+ Ok(())
+}
+
fn run_quantize(
in_files: &[std::path::PathBuf],
out_file: std::path::PathBuf,
@@ -379,6 +401,7 @@ fn main() -> anyhow::Result<()> {
quantization,
mode,
} => run_quantize(&in_file, out_file, quantization, mode)?,
+ Command::Dequantize { in_file, out_file } => run_dequantize(in_file, out_file)?,
}
Ok(())
}