summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-core/src/pickle.rs13
1 files changed, 13 insertions, 0 deletions
diff --git a/candle-core/src/pickle.rs b/candle-core/src/pickle.rs
index 37c15018..0013113a 100644
--- a/candle-core/src/pickle.rs
+++ b/candle-core/src/pickle.rs
@@ -723,3 +723,16 @@ impl PthTensors {
Ok(Some(tensor))
}
}
+
+/// Read all the tensors from a PyTorch pth file.
+pub fn read_all<P: AsRef<std::path::Path>>(path: P) -> Result<Vec<(String, Tensor)>> {
+ let pth = PthTensors::new(path)?;
+ let tensor_names = pth.tensor_infos.keys();
+ let mut tensors = Vec::with_capacity(tensor_names.len());
+ for name in tensor_names {
+ if let Some(tensor) = pth.get(name)? {
+ tensors.push((name.to_string(), tensor))
+ }
+ }
+ Ok(tensors)
+}