diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-08-19 16:50:26 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-08-19 16:50:26 +0100 |
commit | 551409092ecacd0400982eaaa33084c83ef54c57 (patch) | |
tree | a2803b8f4fd0f268c38541f66d93e887294c420a /candle-core/examples | |
parent | 6431140250a2185af6f053da3bea6ea68c937ef3 (diff) | |
download | candle-551409092ecacd0400982eaaa33084c83ef54c57.tar.gz candle-551409092ecacd0400982eaaa33084c83ef54c57.tar.bz2 candle-551409092ecacd0400982eaaa33084c83ef54c57.zip |
Small tweaks to tensor-tools. (#517)
Diffstat (limited to 'candle-core/examples')
-rw-r--r-- | candle-core/examples/tensor-tools.rs | 24 |
1 files changed, 15 insertions, 9 deletions
diff --git a/candle-core/examples/tensor-tools.rs b/candle-core/examples/tensor-tools.rs index e2d12aa5..19c9d0ca 100644 --- a/candle-core/examples/tensor-tools.rs +++ b/candle-core/examples/tensor-tools.rs @@ -6,7 +6,7 @@ enum Format { Safetensors, Npz, Ggml, - PyTorch, + Pth, Pickle, } @@ -16,9 +16,10 @@ impl Format { .extension() .and_then(|e| e.to_str()) .and_then(|e| match e { + // We don't infer any format for .bin as it can be used for ggml or pytorch. "safetensors" | "safetensor" => Some(Self::Safetensors), "npz" => Some(Self::Npz), - "pth" | "pt" => Some(Self::PyTorch), + "pth" | "pt" => Some(Self::Pth), "ggml" => Some(Self::Ggml), _ => None, }) @@ -29,18 +30,19 @@ impl Format { enum Command { Ls { files: Vec<std::path::PathBuf>, + /// The file format to use, if unspecified infer from the file extension. #[arg(long, value_enum)] format: Option<Format>, + + /// Enable verbose mode. + #[arg(short, long)] + verbose: bool, }, } #[derive(Parser, Debug, Clone)] struct Args { - /// Enable verbose mode. - #[arg(short, long)] - verbose: bool, - #[command(subcommand)] command: Command, } @@ -86,7 +88,7 @@ fn run_ls(file: &std::path::PathBuf, format: Option<Format>, verbose: bool) -> R println!("{name}: [{shape:?}; {dtype}]") } } - Format::PyTorch => { + Format::Pth => { let mut tensors = candle_core::pickle::read_pth_tensor_info(file)?; tensors.sort_by(|a, b| a.name.cmp(&b.name)); for tensor_info in tensors.iter() { @@ -126,13 +128,17 @@ fn run_ls(file: &std::path::PathBuf, format: Option<Format>, verbose: bool) -> R fn main() -> anyhow::Result<()> { let args = Args::parse(); match args.command { - Command::Ls { files, format } => { + Command::Ls { + files, + format, + verbose, + } => { let multiple_files = files.len() > 1; for file in files.iter() { if multiple_files { println!("--- {file:?} ---"); } - run_ls(file, format.clone(), args.verbose)? + run_ls(file, format.clone(), verbose)? } } } |