summary | refs | log | tree | commit | diff
path: root/candle-core/examples/tensor-tools.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-core/examples/tensor-tools.rs')
-rw-r--r-- candle-core/examples/tensor-tools.rs 25
1 files changed, 24 insertions, 1 deletions
diff --git a/candle-core/examples/tensor-tools.rs b/candle-core/examples/tensor-tools.rs
index d06b30d1..337021aa 100644
--- a/candle-core/examples/tensor-tools.rs
+++ b/candle-core/examples/tensor-tools.rs
@@ -102,7 +102,7 @@ enum Command {
},
Quantize {
- /// The input file, in gguf format.
+ /// The input file(s), in safetensors format.
in_file: Vec<std::path::PathBuf>,
/// The output file, in gguf format.
@@ -117,6 +117,15 @@ enum Command {
#[arg(long, value_enum, default_value_t = QuantizationMode::Llama)]
mode: QuantizationMode,
},
+
+ Dequantize {
+ /// The input file, in gguf format.
+ in_file: std::path::PathBuf,
+
+ /// The output file, in safetensors format.
+ #[arg(long)]
+ out_file: std::path::PathBuf,
+ },
}
#[derive(Parser, Debug, Clone)]
@@ -285,6 +294,19 @@ fn run_quantize_safetensors(
Ok(())
}
+fn run_dequantize(in_file: std::path::PathBuf, out_file: std::path::PathBuf) -> Result<()> { // Dequantize every tensor of the gguf file `in_file` on the CPU and save them all to `out_file` via `safetensors::save`; any failing step is propagated with `?`.
+ let mut in_file = std::fs::File::open(in_file)?; // shadows the path argument with the opened file handle
+ let content = gguf_file::Content::read(&mut in_file)?; // read the gguf `Content`; its `tensor_infos` drive the loop below
+ let mut tensors = std::collections::HashMap::new(); // tensor name -> dequantized tensor
+ for (tensor_name, _) in content.tensor_infos.iter() { // only the names are used here; the associated info value is ignored
+ let tensor = content.tensor(&mut in_file, tensor_name)?; // fetch this tensor by name, reading through the same file handle
+ let tensor = tensor.dequantize(&Device::Cpu)?; // dequantize with `Device::Cpu` as the target device
+ tensors.insert(tensor_name.to_string(), tensor); // key the map by an owned copy of the name
+ }
+ candle_core::safetensors::save(&tensors, out_file)?; // write the whole map to `out_file` in a single call
+ Ok(())
+}
+
fn run_quantize(
in_files: &[std::path::PathBuf],
out_file: std::path::PathBuf,
@@ -379,6 +401,7 @@ fn main() -> anyhow::Result<()> {
quantization,
mode,
} => run_quantize(&in_file, out_file, quantization, mode)?,
+ Command::Dequantize { in_file, out_file } => run_dequantize(in_file, out_file)?,
}
Ok(())
}