summaryrefslogtreecommitdiff
path: root/candle-core/examples/tensor-tools.rs
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2024-03-30 11:34:33 +0100
committerGitHub <noreply@github.com>2024-03-30 11:34:33 +0100
commitefe4a0c84b55b60f7555a89ea7e0ba8d300104cd (patch)
tree41dfddebfd21b5bb637c8b6394ceab74394520f9 /candle-core/examples/tensor-tools.rs
parent665da304878326e267b178fa6e6d85424249126b (diff)
downloadcandle-efe4a0c84b55b60f7555a89ea7e0ba8d300104cd.tar.gz
candle-efe4a0c84b55b60f7555a89ea7e0ba8d300104cd.tar.bz2
candle-efe4a0c84b55b60f7555a89ea7e0ba8d300104cd.zip
Add a print command to tensor-tools. (#1967)
* Add a print command to tensor-tools. * Add some flags to tweak the formatting.
Diffstat (limited to 'candle-core/examples/tensor-tools.rs')
-rw-r--r--candle-core/examples/tensor-tools.rs124
1 files changed, 124 insertions, 0 deletions
diff --git a/candle-core/examples/tensor-tools.rs b/candle-core/examples/tensor-tools.rs
index 1801ac58..5dc49cd8 100644
--- a/candle-core/examples/tensor-tools.rs
+++ b/candle-core/examples/tensor-tools.rs
@@ -117,6 +117,24 @@ enum Command {
verbose: bool,
},
+ Print {
+ file: std::path::PathBuf,
+
+ names: Vec<String>,
+
+ /// The file format to use, if unspecified infer from the file extension.
+ #[arg(long, value_enum)]
+ format: Option<Format>,
+
+ /// Print the whole content of each tensor.
+ #[arg(long)]
+ full: bool,
+
+ /// Line width for printing the tensors.
+ #[arg(long)]
+ line_width: Option<usize>,
+ },
+
Quantize {
/// The input file(s), in safetensors format.
in_file: Vec<std::path::PathBuf>,
@@ -150,6 +168,105 @@ struct Args {
command: Command,
}
+fn run_print(
+ file: &std::path::PathBuf,
+ names: Vec<String>,
+ format: Option<Format>,
+ full: bool,
+ line_width: Option<usize>,
+ device: &Device,
+) -> Result<()> {
+ if full {
+ candle_core::display::set_print_options_full();
+ }
+ if let Some(line_width) = line_width {
+ candle_core::display::set_line_width(line_width)
+ }
+ let format = match format {
+ Some(format) => format,
+ None => match Format::infer(file) {
+ Some(format) => format,
+ None => {
+ println!(
+ "{file:?}: cannot infer format from file extension, use the --format flag"
+ );
+ return Ok(());
+ }
+ },
+ };
+ match format {
+ Format::Npz => {
+ let tensors = candle_core::npy::NpzTensors::new(file)?;
+ for name in names.iter() {
+ println!("==== {name} ====");
+ match tensors.get(name)? {
+ Some(tensor) => println!("{tensor}"),
+ None => println!("not found"),
+ }
+ }
+ }
+ Format::Safetensors => {
+ use candle_core::safetensors::Load;
+ let tensors = unsafe { candle_core::safetensors::MmapedSafetensors::new(file)? };
+ let tensors: std::collections::HashMap<_, _> = tensors.tensors().into_iter().collect();
+ for name in names.iter() {
+ println!("==== {name} ====");
+ match tensors.get(name) {
+ Some(tensor_view) => {
+ let tensor = tensor_view.load(device)?;
+ println!("{tensor}")
+ }
+ None => println!("not found"),
+ }
+ }
+ }
+ Format::Pth => {
+ let pth_file = candle_core::pickle::PthTensors::new(file, None)?;
+ for name in names.iter() {
+ println!("==== {name} ====");
+ match pth_file.get(name)? {
+ Some(tensor) => {
+ println!("{tensor}")
+ }
+ None => println!("not found"),
+ }
+ }
+ }
+ Format::Pickle => {
+ candle_core::bail!("pickle format is not supported for print")
+ }
+ Format::Ggml => {
+ let mut file = std::fs::File::open(file)?;
+ let content = candle_core::quantized::ggml_file::Content::read(&mut file, device)?;
+ for name in names.iter() {
+ println!("==== {name} ====");
+ match content.tensors.get(name) {
+ Some(tensor) => {
+ let tensor = tensor.dequantize(device)?;
+ println!("{tensor}")
+ }
+ None => println!("not found"),
+ }
+ }
+ }
+ Format::Gguf => {
+ let mut file = std::fs::File::open(file)?;
+ let content = gguf_file::Content::read(&mut file)?;
+ for name in names.iter() {
+ println!("==== {name} ====");
+ match content.tensor(&mut file, name, device) {
+ Ok(tensor) => {
+ let tensor = tensor.dequantize(device)?;
+ println!("{tensor}")
+ }
+ Err(_) => println!("not found"),
+ }
+ }
+ }
+ }
+ Ok(())
+}
+
fn run_ls(
file: &std::path::PathBuf,
format: Option<Format>,
@@ -377,6 +494,13 @@ fn main() -> anyhow::Result<()> {
run_ls(file, format.clone(), verbose, &device)?
}
}
+ Command::Print {
+ file,
+ names,
+ format,
+ full,
+ line_width,
+ } => run_print(&file, names, format, full, line_width, &device)?,
Command::Quantize {
in_file,
out_file,