summaryrefslogtreecommitdiff
path: root/candle-core/examples
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-08-19 15:05:34 +0100
committerGitHub <noreply@github.com>2023-08-19 15:05:34 +0100
commit607ffb9f1edebfb74f30b82a0731c16a704f9f9b (patch)
tree74b090b6af871c33b5dd4fd1a5b463a78b14a215 /candle-core/examples
parentf861a9df6ef35bf5e2df5891d3af029e9139b0d8 (diff)
downloadcandle-607ffb9f1edebfb74f30b82a0731c16a704f9f9b.tar.gz
candle-607ffb9f1edebfb74f30b82a0731c16a704f9f9b.tar.bz2
candle-607ffb9f1edebfb74f30b82a0731c16a704f9f9b.zip
Retrieve more information from PyTorch checkpoints. (#515)
* Retrieve more information from PyTorch checkpoints. * Add enough support to load dino-v2 backbone weights.
Diffstat (limited to 'candle-core/examples')
-rw-r--r--candle-core/examples/tensor-tools.rs12
1 files changed, 9 insertions, 3 deletions
diff --git a/candle-core/examples/tensor-tools.rs b/candle-core/examples/tensor-tools.rs
index 229ed489..1e01d8b9 100644
--- a/candle-core/examples/tensor-tools.rs
+++ b/candle-core/examples/tensor-tools.rs
@@ -88,9 +88,15 @@ fn run_ls(file: &std::path::PathBuf, format: Option<Format>) -> Result<()> {
}
Format::PyTorch => {
let mut tensors = candle_core::pickle::read_pth_tensor_info(file)?;
- tensors.sort_by(|a, b| a.0.cmp(&b.0));
- for (name, dtype, shape) in tensors.iter() {
- println!("{name}: [{shape:?}; {dtype:?}]")
+ tensors.sort_by(|a, b| a.name.cmp(&b.name));
+ for tensor_info in tensors.iter() {
+ println!(
+ "{}: [{:?}; {:?}] {:?}",
+ tensor_info.name,
+ tensor_info.layout.shape(),
+ tensor_info.dtype,
+ tensor_info.path,
+ )
}
}
Format::Pickle => {