summaryrefslogtreecommitdiff
path: root/tensor-tools/src
diff options
context:
space:
mode:
Diffstat (limited to 'tensor-tools/src')
-rw-r--r--tensor-tools/src/main.rs29
1 files changed, 29 insertions, 0 deletions
diff --git a/tensor-tools/src/main.rs b/tensor-tools/src/main.rs
index ad351171..0bda36d5 100644
--- a/tensor-tools/src/main.rs
+++ b/tensor-tools/src/main.rs
@@ -197,6 +197,11 @@ fn run_print(
match format {
Format::Npz => {
let tensors = candle::npy::NpzTensors::new(file)?;
+ let names = if names.is_empty() {
+ tensors.names().into_iter().map(|v| v.to_string()).collect()
+ } else {
+ names
+ };
for name in names.iter() {
println!("==== {name} ====");
match tensors.get(name)? {
@@ -209,6 +214,11 @@ fn run_print(
use candle::safetensors::Load;
let tensors = unsafe { candle::safetensors::MmapedSafetensors::new(file)? };
let tensors: std::collections::HashMap<_, _> = tensors.tensors().into_iter().collect();
+ let names = if names.is_empty() {
+ tensors.keys().map(|v| v.to_string()).collect()
+ } else {
+ names
+ };
for name in names.iter() {
println!("==== {name} ====");
match tensors.get(name) {
@@ -222,6 +232,15 @@ fn run_print(
}
Format::Pth => {
let pth_file = candle::pickle::PthTensors::new(file, None)?;
+ let names = if names.is_empty() {
+ pth_file
+ .tensor_infos()
+ .keys()
+ .map(|v| v.to_string())
+ .collect()
+ } else {
+ names
+ };
for name in names.iter() {
println!("==== {name} ====");
match pth_file.get(name)? {
@@ -238,6 +257,11 @@ fn run_print(
Format::Ggml => {
let mut file = std::fs::File::open(file)?;
let content = candle::quantized::ggml_file::Content::read(&mut file, device)?;
+ let names = if names.is_empty() {
+ content.tensors.keys().map(|v| v.to_string()).collect()
+ } else {
+ names
+ };
for name in names.iter() {
println!("==== {name} ====");
match content.tensors.get(name) {
@@ -252,6 +276,11 @@ fn run_print(
Format::Gguf => {
let mut file = std::fs::File::open(file)?;
let content = gguf_file::Content::read(&mut file)?;
+ let names = if names.is_empty() {
+ content.tensor_infos.keys().map(|v| v.to_string()).collect()
+ } else {
+ names
+ };
for name in names.iter() {
println!("==== {name} ====");
match content.tensor(&mut file, name, device) {