diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-10-05 10:05:14 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-10-05 10:05:14 +0200 |
commit | d2e432914ec495baff1db29799fe316b9190b0e9 (patch) | |
tree | 616e67c53d1f8b4a7051e2416f105fce2835fd5d /tensor-tools/src | |
parent | 410c89f72a0ab22a299d02d24f505a50522faaa2 (diff) | |
download | candle-d2e432914ec495baff1db29799fe316b9190b0e9.tar.gz candle-d2e432914ec495baff1db29799fe316b9190b0e9.tar.bz2 candle-d2e432914ec495baff1db29799fe316b9190b0e9.zip |
Tensor tools print all (#2543)
* Support whisper large-v3 turbo in the whisper-microphone example.
* Print all tensors when no argument is provided.
Diffstat (limited to 'tensor-tools/src')
-rw-r--r-- | tensor-tools/src/main.rs | 29 |
1 files changed, 29 insertions, 0 deletions
diff --git a/tensor-tools/src/main.rs b/tensor-tools/src/main.rs index ad351171..0bda36d5 100644 --- a/tensor-tools/src/main.rs +++ b/tensor-tools/src/main.rs @@ -197,6 +197,11 @@ fn run_print( match format { Format::Npz => { let tensors = candle::npy::NpzTensors::new(file)?; + let names = if names.is_empty() { + tensors.names().into_iter().map(|v| v.to_string()).collect() + } else { + names + }; for name in names.iter() { println!("==== {name} ===="); match tensors.get(name)? { @@ -209,6 +214,11 @@ fn run_print( use candle::safetensors::Load; let tensors = unsafe { candle::safetensors::MmapedSafetensors::new(file)? }; let tensors: std::collections::HashMap<_, _> = tensors.tensors().into_iter().collect(); + let names = if names.is_empty() { + tensors.keys().map(|v| v.to_string()).collect() + } else { + names + }; for name in names.iter() { println!("==== {name} ===="); match tensors.get(name) { @@ -222,6 +232,15 @@ fn run_print( } Format::Pth => { let pth_file = candle::pickle::PthTensors::new(file, None)?; + let names = if names.is_empty() { + pth_file + .tensor_infos() + .keys() + .map(|v| v.to_string()) + .collect() + } else { + names + }; for name in names.iter() { println!("==== {name} ===="); match pth_file.get(name)? { @@ -238,6 +257,11 @@ fn run_print( Format::Ggml => { let mut file = std::fs::File::open(file)?; let content = candle::quantized::ggml_file::Content::read(&mut file, device)?; + let names = if names.is_empty() { + content.tensors.keys().map(|v| v.to_string()).collect() + } else { + names + }; for name in names.iter() { println!("==== {name} ===="); match content.tensors.get(name) { @@ -252,6 +276,11 @@ fn run_print( Format::Gguf => { let mut file = std::fs::File::open(file)?; let content = gguf_file::Content::read(&mut file)?; + let names = if names.is_empty() { + content.tensor_infos.keys().map(|v| v.to_string()).collect() + } else { + names + }; for name in names.iter() { println!("==== {name} ===="); match content.tensor(&mut file, name, device) { |