diff options
Diffstat (limited to 'candle-core/tests')
-rw-r--r-- | candle-core/tests/pth.py | 2 | ||||
-rw-r--r-- | candle-core/tests/pth_tests.rs | 10 | ||||
-rw-r--r-- | candle-core/tests/test_with_key.pt | bin | 0 -> 1338 bytes |
3 files changed, 11 insertions, 1 deletions
diff --git a/candle-core/tests/pth.py b/candle-core/tests/pth.py index 97724712..cab94f2c 100644 --- a/candle-core/tests/pth.py +++ b/candle-core/tests/pth.py @@ -6,3 +6,5 @@ a= torch.tensor([[1,2,3,4], [5,6,7,8]]) o = OrderedDict() o["test"] = a torch.save(o, "test.pt") + +torch.save({"model_state_dict": o}, "test_with_key.pt") diff --git a/candle-core/tests/pth_tests.rs b/candle-core/tests/pth_tests.rs index b09d1026..ad788ed9 100644 --- a/candle-core/tests/pth_tests.rs +++ b/candle-core/tests/pth_tests.rs @@ -1,6 +1,14 @@ /// Regression test for pth files not loading on Windows. #[test] fn test_pth() { - let tensors = candle_core::pickle::PthTensors::new("tests/test.pt").unwrap(); + let tensors = candle_core::pickle::PthTensors::new("tests/test.pt", None).unwrap(); + tensors.get("test").unwrap().unwrap(); +} + +#[test] +fn test_pth_with_key() { + let tensors = + candle_core::pickle::PthTensors::new("tests/test_with_key.pt", Some("model_state_dict")) + .unwrap(); tensors.get("test").unwrap().unwrap(); } diff --git a/candle-core/tests/test_with_key.pt b/candle-core/tests/test_with_key.pt Binary files differnew file mode 100644 index 00000000..a598e02c --- /dev/null +++ b/candle-core/tests/test_with_key.pt |