diff options
Diffstat (limited to 'candle-core/examples/tensor-tools.rs')
-rw-r--r-- | candle-core/examples/tensor-tools.rs | 24 |
1 files changed, 15 insertions, 9 deletions
diff --git a/candle-core/examples/tensor-tools.rs b/candle-core/examples/tensor-tools.rs index e2d12aa5..19c9d0ca 100644 --- a/candle-core/examples/tensor-tools.rs +++ b/candle-core/examples/tensor-tools.rs @@ -6,7 +6,7 @@ enum Format { Safetensors, Npz, Ggml, - PyTorch, + Pth, Pickle, } @@ -16,9 +16,10 @@ impl Format { .extension() .and_then(|e| e.to_str()) .and_then(|e| match e { + // We don't infer any format for .bin as it can be used for ggml or pytorch. "safetensors" | "safetensor" => Some(Self::Safetensors), "npz" => Some(Self::Npz), - "pth" | "pt" => Some(Self::PyTorch), + "pth" | "pt" => Some(Self::Pth), "ggml" => Some(Self::Ggml), _ => None, }) @@ -29,18 +30,19 @@ impl Format { enum Command { Ls { files: Vec<std::path::PathBuf>, + /// The file format to use, if unspecified infer from the file extension. #[arg(long, value_enum)] format: Option<Format>, + + /// Enable verbose mode. + #[arg(short, long)] + verbose: bool, }, } #[derive(Parser, Debug, Clone)] struct Args { - /// Enable verbose mode. - #[arg(short, long)] - verbose: bool, - #[command(subcommand)] command: Command, } @@ -86,7 +88,7 @@ fn run_ls(file: &std::path::PathBuf, format: Option<Format>, verbose: bool) -> R println!("{name}: [{shape:?}; {dtype}]") } } - Format::PyTorch => { + Format::Pth => { let mut tensors = candle_core::pickle::read_pth_tensor_info(file)?; tensors.sort_by(|a, b| a.name.cmp(&b.name)); for tensor_info in tensors.iter() { @@ -126,13 +128,17 @@ fn run_ls(file: &std::path::PathBuf, format: Option<Format>, verbose: bool) -> R fn main() -> anyhow::Result<()> { let args = Args::parse(); match args.command { - Command::Ls { files, format } => { + Command::Ls { + files, + format, + verbose, + } => { let multiple_files = files.len() > 1; for file in files.iter() { if multiple_files { println!("--- {file:?} ---"); } - run_ls(file, format.clone(), args.verbose)? + run_ls(file, format.clone(), verbose)? } } } |