diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-08-19 16:50:26 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-08-19 16:50:26 +0100 |
commit | 551409092ecacd0400982eaaa33084c83ef54c57 (patch) | |
tree | a2803b8f4fd0f268c38541f66d93e887294c420a | |
parent | 6431140250a2185af6f053da3bea6ea68c937ef3 (diff) | |
download | candle-551409092ecacd0400982eaaa33084c83ef54c57.tar.gz candle-551409092ecacd0400982eaaa33084c83ef54c57.tar.bz2 candle-551409092ecacd0400982eaaa33084c83ef54c57.zip |
Small tweaks to tensor-tools. (#517)
-rw-r--r-- | candle-core/examples/tensor-tools.rs | 24 | ||||
-rw-r--r-- | candle-core/src/pickle.rs | 9 |
2 files changed, 21 insertions, 12 deletions
diff --git a/candle-core/examples/tensor-tools.rs b/candle-core/examples/tensor-tools.rs index e2d12aa5..19c9d0ca 100644 --- a/candle-core/examples/tensor-tools.rs +++ b/candle-core/examples/tensor-tools.rs @@ -6,7 +6,7 @@ enum Format { Safetensors, Npz, Ggml, - PyTorch, + Pth, Pickle, } @@ -16,9 +16,10 @@ impl Format { .extension() .and_then(|e| e.to_str()) .and_then(|e| match e { + // We don't infer any format for .bin as it can be used for ggml or pytorch. "safetensors" | "safetensor" => Some(Self::Safetensors), "npz" => Some(Self::Npz), - "pth" | "pt" => Some(Self::PyTorch), + "pth" | "pt" => Some(Self::Pth), "ggml" => Some(Self::Ggml), _ => None, }) @@ -29,18 +30,19 @@ impl Format { enum Command { Ls { files: Vec<std::path::PathBuf>, + /// The file format to use, if unspecified infer from the file extension. #[arg(long, value_enum)] format: Option<Format>, + + /// Enable verbose mode. + #[arg(short, long)] + verbose: bool, }, } #[derive(Parser, Debug, Clone)] struct Args { - /// Enable verbose mode. - #[arg(short, long)] - verbose: bool, - #[command(subcommand)] command: Command, } @@ -86,7 +88,7 @@ fn run_ls(file: &std::path::PathBuf, format: Option<Format>, verbose: bool) -> R println!("{name}: [{shape:?}; {dtype}]") } } - Format::PyTorch => { + Format::Pth => { let mut tensors = candle_core::pickle::read_pth_tensor_info(file)?; tensors.sort_by(|a, b| a.name.cmp(&b.name)); for tensor_info in tensors.iter() { @@ -126,13 +128,17 @@ fn run_ls(file: &std::path::PathBuf, format: Option<Format>, verbose: bool) -> R fn main() -> anyhow::Result<()> { let args = Args::parse(); match args.command { - Command::Ls { files, format } => { + Command::Ls { + files, + format, + verbose, + } => { let multiple_files = files.len() > 1; for file in files.iter() { if multiple_files { println!("--- {file:?} ---"); } - run_ls(file, format.clone(), args.verbose)? + run_ls(file, format.clone(), verbose)? } } } diff --git a/candle-core/src/pickle.rs b/candle-core/src/pickle.rs index f14a5046..e913935c 100644 --- a/candle-core/src/pickle.rs +++ b/candle-core/src/pickle.rs @@ -490,13 +490,14 @@ impl From<Object> for E { // https://github.com/pytorch/pytorch/blob/4eac43d046ded0f0a5a5fa8db03eb40f45bf656e/torch/_utils.py#L198 // Arguments: storage, storage_offset, size, stride, requires_grad, backward_hooks -fn rebuild_args(args: Object) -> Result<(Layout, DType, String)> { +fn rebuild_args(args: Object) -> Result<(Layout, DType, String, usize)> { let mut args = args.tuple()?; let stride = Vec::<usize>::try_from(args.remove(3))?; let size = Vec::<usize>::try_from(args.remove(2))?; let offset = args.remove(1).int()? as usize; let storage = args.remove(0).persistent_load()?; let mut storage = storage.tuple()?; + let storage_size = storage.remove(4).int()? as usize; let path = storage.remove(2).unicode()?; let (_module_name, class_name) = storage.remove(1).class()?; let dtype = match class_name.as_str() { @@ -510,7 +511,7 @@ fn rebuild_args(args: Object) -> Result<(Layout, DType, String)> { } }; let layout = Layout::new(crate::Shape::from(size), stride, offset); - Ok((layout, dtype, path)) + Ok((layout, dtype, path, storage_size)) } #[derive(Debug, Clone)] @@ -519,6 +520,7 @@ pub struct TensorInfo { pub dtype: DType, pub layout: Layout, pub path: String, + pub storage_size: usize, } pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(file: P) -> Result<Vec<TensorInfo>> { @@ -576,7 +578,7 @@ pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(file: P) -> Result<Vec<Te _ => continue, }; match rebuild_args(args) { - Ok((layout, dtype, file_path)) => { + Ok((layout, dtype, file_path, storage_size)) => { let mut path = dir_name.clone(); path.push(file_path); tensor_infos.push(TensorInfo { @@ -584,6 +586,7 @@ pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(file: P) -> Result<Vec<Te dtype, layout, path: path.to_string_lossy().into_owned(), + storage_size, }) } Err(err) => { |