summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-core/src/pickle.rs4
-rw-r--r--candle-core/tests/pth.py8
-rw-r--r--candle-core/tests/pth_tests.rs6
-rw-r--r--candle-core/tests/test.ptbin0 -> 1165 bytes
4 files changed, 15 insertions, 3 deletions
diff --git a/candle-core/src/pickle.rs b/candle-core/src/pickle.rs
index 276b30e3..4c76c416 100644
--- a/candle-core/src/pickle.rs
+++ b/candle-core/src/pickle.rs
@@ -227,13 +227,11 @@ impl Object {
_ => return Ok(None),
};
let (layout, dtype, file_path, storage_size) = rebuild_args(args)?;
- let mut path = dir_name.to_path_buf();
- path.push(file_path);
Ok(Some(TensorInfo {
name,
dtype,
layout,
- path: path.to_string_lossy().into_owned(),
+ path: format!("{}/{}", dir_name.to_string_lossy(), file_path),
storage_size,
}))
}
diff --git a/candle-core/tests/pth.py b/candle-core/tests/pth.py
new file mode 100644
index 00000000..97724712
--- /dev/null
+++ b/candle-core/tests/pth.py
@@ -0,0 +1,8 @@
+import torch
+from collections import OrderedDict
+
+# Write a trivial tensor to a pt file
+a= torch.tensor([[1,2,3,4], [5,6,7,8]])
+o = OrderedDict()
+o["test"] = a
+torch.save(o, "test.pt")
diff --git a/candle-core/tests/pth_tests.rs b/candle-core/tests/pth_tests.rs
new file mode 100644
index 00000000..16bac526
--- /dev/null
+++ b/candle-core/tests/pth_tests.rs
@@ -0,0 +1,6 @@
+/// Regression test for pth files not loading on Windows.
+#[test]
+fn test_pth() {
+ let tensors = candle_core::pickle::PthTensors::new("tests/test.pt").unwrap();
+ tensors.get("test").unwrap().unwrap();
+}
diff --git a/candle-core/tests/test.pt b/candle-core/tests/test.pt
new file mode 100644
index 00000000..f2fa7da3
--- /dev/null
+++ b/candle-core/tests/test.pt
Binary files differ