summaryrefslogtreecommitdiff
path: root/candle-core/tests
diff options
context:
space:
mode:
Diffstat (limited to 'candle-core/tests')
-rw-r--r--candle-core/tests/pth.py2
-rw-r--r--candle-core/tests/pth_tests.rs10
-rw-r--r--candle-core/tests/test_with_key.ptbin0 -> 1338 bytes
3 files changed, 11 insertions, 1 deletions
diff --git a/candle-core/tests/pth.py b/candle-core/tests/pth.py
index 97724712..cab94f2c 100644
--- a/candle-core/tests/pth.py
+++ b/candle-core/tests/pth.py
@@ -6,3 +6,5 @@ a= torch.tensor([[1,2,3,4], [5,6,7,8]])
o = OrderedDict()
o["test"] = a
torch.save(o, "test.pt")
+
+torch.save({"model_state_dict": o}, "test_with_key.pt")
diff --git a/candle-core/tests/pth_tests.rs b/candle-core/tests/pth_tests.rs
index b09d1026..ad788ed9 100644
--- a/candle-core/tests/pth_tests.rs
+++ b/candle-core/tests/pth_tests.rs
@@ -1,6 +1,14 @@
/// Regression test for pth files not loading on Windows.
#[test]
fn test_pth() {
- let tensors = candle_core::pickle::PthTensors::new("tests/test.pt").unwrap();
+ let tensors = candle_core::pickle::PthTensors::new("tests/test.pt", None).unwrap();
+ tensors.get("test").unwrap().unwrap();
+}
+
+#[test]
+fn test_pth_with_key() {
+ let tensors =
+ candle_core::pickle::PthTensors::new("tests/test_with_key.pt", Some("model_state_dict"))
+ .unwrap();
tensors.get("test").unwrap().unwrap();
}
diff --git a/candle-core/tests/test_with_key.pt b/candle-core/tests/test_with_key.pt
new file mode 100644
index 00000000..a598e02c
--- /dev/null
+++ b/candle-core/tests/test_with_key.pt
Binary files differ