diff options
Diffstat (limited to 'candle-core/examples/tensor-tools.rs')
-rw-r--r-- | candle-core/examples/tensor-tools.rs | 124 |
1 files changed, 124 insertions, 0 deletions
diff --git a/candle-core/examples/tensor-tools.rs b/candle-core/examples/tensor-tools.rs index 1801ac58..5dc49cd8 100644 --- a/candle-core/examples/tensor-tools.rs +++ b/candle-core/examples/tensor-tools.rs @@ -117,6 +117,24 @@ enum Command { verbose: bool, }, + Print { + file: std::path::PathBuf, + + names: Vec<String>, + + /// The file format to use, if unspecified infer from the file extension. + #[arg(long, value_enum)] + format: Option<Format>, + + /// Print the whole content of each tensor. + #[arg(long)] + full: bool, + + /// Line width for printing the tensors. + #[arg(long)] + line_width: Option<usize>, + }, + Quantize { /// The input file(s), in safetensors format. in_file: Vec<std::path::PathBuf>, @@ -150,6 +168,105 @@ struct Args { command: Command, } +fn run_print( + file: &std::path::PathBuf, + names: Vec<String>, + format: Option<Format>, + full: bool, + line_width: Option<usize>, + device: &Device, +) -> Result<()> { + if full { + candle_core::display::set_print_options_full(); + } + if let Some(line_width) = line_width { + candle_core::display::set_line_width(line_width) + } + let format = match format { + Some(format) => format, + None => match Format::infer(file) { + Some(format) => format, + None => { + println!( + "{file:?}: cannot infer format from file extension, use the --format flag" + ); + return Ok(()); + } + }, + }; + match format { + Format::Npz => { + let tensors = candle_core::npy::NpzTensors::new(file)?; + for name in names.iter() { + println!("==== {name} ===="); + match tensors.get(name)? { + Some(tensor) => println!("{tensor}"), + None => println!("not found"), + } + } + } + Format::Safetensors => { + use candle_core::safetensors::Load; + let tensors = unsafe { candle_core::safetensors::MmapedSafetensors::new(file)? }; + let tensors: std::collections::HashMap<_, _> = tensors.tensors().into_iter().collect(); + for name in names.iter() { + println!("==== {name} ===="); + match tensors.get(name) { + Some(tensor_view) => { + let tensor = tensor_view.load(device)?; + println!("{tensor}") + } + None => println!("not found"), + } + } + } + Format::Pth => { + let pth_file = candle_core::pickle::PthTensors::new(file, None)?; + for name in names.iter() { + println!("==== {name} ===="); + match pth_file.get(name)? { + Some(tensor) => { + println!("{tensor}") + } + None => println!("not found"), + } + } + } + Format::Pickle => { + candle_core::bail!("pickle format is not supported for print") + } + Format::Ggml => { + let mut file = std::fs::File::open(file)?; + let content = candle_core::quantized::ggml_file::Content::read(&mut file, device)?; + for name in names.iter() { + println!("==== {name} ===="); + match content.tensors.get(name) { + Some(tensor) => { + let tensor = tensor.dequantize(device)?; + println!("{tensor}") + } + None => println!("not found"), + } + } + } + Format::Gguf => { + let mut file = std::fs::File::open(file)?; + let content = gguf_file::Content::read(&mut file)?; + for name in names.iter() { + println!("==== {name} ===="); + match content.tensor(&mut file, name, device) { + Ok(tensor) => { + let tensor = tensor.dequantize(device)?; + println!("{tensor}") + } + Err(_) => println!("not found"), + } + } + } + } + Ok(()) +} + fn run_ls( file: &std::path::PathBuf, format: Option<Format>, @@ -377,6 +494,13 @@ fn main() -> anyhow::Result<()> { run_ls(file, format.clone(), verbose, &device)? } } + Command::Print { + file, + names, + format, + full, + line_width, + } => run_print(&file, names, format, full, line_width, &device)?, Command::Quantize { in_file, out_file, |