summaryrefslogtreecommitdiff
path: root/tensor-tools/src
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2024-10-05 10:05:14 +0200
committerGitHub <noreply@github.com>2024-10-05 10:05:14 +0200
commitd2e432914ec495baff1db29799fe316b9190b0e9 (patch)
tree616e67c53d1f8b4a7051e2416f105fce2835fd5d /tensor-tools/src
parent410c89f72a0ab22a299d02d24f505a50522faaa2 (diff)
downloadcandle-d2e432914ec495baff1db29799fe316b9190b0e9.tar.gz
candle-d2e432914ec495baff1db29799fe316b9190b0e9.tar.bz2
candle-d2e432914ec495baff1db29799fe316b9190b0e9.zip
Tensor tools print all (#2543)
* Support whisper large-v3 turbo in the whisper-microphone example. * Print all tensors when no argument is provided.
Diffstat (limited to 'tensor-tools/src')
-rw-r--r--tensor-tools/src/main.rs29
1 files changed, 29 insertions, 0 deletions
diff --git a/tensor-tools/src/main.rs b/tensor-tools/src/main.rs
index ad351171..0bda36d5 100644
--- a/tensor-tools/src/main.rs
+++ b/tensor-tools/src/main.rs
@@ -197,6 +197,11 @@ fn run_print(
match format {
Format::Npz => {
let tensors = candle::npy::NpzTensors::new(file)?;
+ let names = if names.is_empty() {
+ tensors.names().into_iter().map(|v| v.to_string()).collect()
+ } else {
+ names
+ };
for name in names.iter() {
println!("==== {name} ====");
match tensors.get(name)? {
@@ -209,6 +214,11 @@ fn run_print(
use candle::safetensors::Load;
let tensors = unsafe { candle::safetensors::MmapedSafetensors::new(file)? };
let tensors: std::collections::HashMap<_, _> = tensors.tensors().into_iter().collect();
+ let names = if names.is_empty() {
+ tensors.keys().map(|v| v.to_string()).collect()
+ } else {
+ names
+ };
for name in names.iter() {
println!("==== {name} ====");
match tensors.get(name) {
@@ -222,6 +232,15 @@ fn run_print(
}
Format::Pth => {
let pth_file = candle::pickle::PthTensors::new(file, None)?;
+ let names = if names.is_empty() {
+ pth_file
+ .tensor_infos()
+ .keys()
+ .map(|v| v.to_string())
+ .collect()
+ } else {
+ names
+ };
for name in names.iter() {
println!("==== {name} ====");
match pth_file.get(name)? {
@@ -238,6 +257,11 @@ fn run_print(
Format::Ggml => {
let mut file = std::fs::File::open(file)?;
let content = candle::quantized::ggml_file::Content::read(&mut file, device)?;
+ let names = if names.is_empty() {
+ content.tensors.keys().map(|v| v.to_string()).collect()
+ } else {
+ names
+ };
for name in names.iter() {
println!("==== {name} ====");
match content.tensors.get(name) {
@@ -252,6 +276,11 @@ fn run_print(
Format::Gguf => {
let mut file = std::fs::File::open(file)?;
let content = gguf_file::Content::read(&mut file)?;
+ let names = if names.is_empty() {
+ content.tensor_infos.keys().map(|v| v.to_string()).collect()
+ } else {
+ names
+ };
for name in names.iter() {
println!("==== {name} ====");
match content.tensor(&mut file, name, device) {