diff options
-rw-r--r-- | candle-core/src/pickle.rs | 4 | ||||
-rw-r--r-- | candle-core/tests/pth.py | 8 | ||||
-rw-r--r-- | candle-core/tests/pth_tests.rs | 6 | ||||
-rw-r--r-- | candle-core/tests/test.pt | bin | 0 -> 1165 bytes |
4 files changed, 15 insertions, 3 deletions
diff --git a/candle-core/src/pickle.rs b/candle-core/src/pickle.rs index 276b30e3..4c76c416 100644 --- a/candle-core/src/pickle.rs +++ b/candle-core/src/pickle.rs @@ -227,13 +227,11 @@ impl Object { _ => return Ok(None), }; let (layout, dtype, file_path, storage_size) = rebuild_args(args)?; - let mut path = dir_name.to_path_buf(); - path.push(file_path); Ok(Some(TensorInfo { name, dtype, layout, - path: path.to_string_lossy().into_owned(), + path: format!("{}/{}", dir_name.to_string_lossy(), file_path), storage_size, })) } diff --git a/candle-core/tests/pth.py b/candle-core/tests/pth.py new file mode 100644 index 00000000..97724712 --- /dev/null +++ b/candle-core/tests/pth.py @@ -0,0 +1,8 @@ +import torch +from collections import OrderedDict + +# Write a trivial tensor to a pt file +a= torch.tensor([[1,2,3,4], [5,6,7,8]]) +o = OrderedDict() +o["test"] = a +torch.save(o, "test.pt") diff --git a/candle-core/tests/pth_tests.rs b/candle-core/tests/pth_tests.rs new file mode 100644 index 00000000..16bac526 --- /dev/null +++ b/candle-core/tests/pth_tests.rs @@ -0,0 +1,6 @@ +/// Regression test for pth files not loading on Windows. +#[test] +fn test_pth() { + let tensors = candle_core::pickle::PthTensors::new("tests/test.pt").unwrap(); + tensors.get("test").unwrap().unwrap(); +} diff --git a/candle-core/tests/test.pt b/candle-core/tests/test.pt Binary files differnew file mode 100644 index 00000000..f2fa7da3 --- /dev/null +++ b/candle-core/tests/test.pt |