summaryrefslogtreecommitdiff
path: root/candle-pyo3/test_pytorch.py
diff options
context:
space:
mode:
Diffstat (limited to 'candle-pyo3/test_pytorch.py')
-rw-r--r--candle-pyo3/test_pytorch.py14
1 files changed, 14 insertions, 0 deletions
diff --git a/candle-pyo3/test_pytorch.py b/candle-pyo3/test_pytorch.py
new file mode 100644
index 00000000..db0f3522
--- /dev/null
+++ b/candle-pyo3/test_pytorch.py
@@ -0,0 +1,14 @@
+import candle
+import torch
+
+# convert from candle tensor to torch tensor
+t = candle.randn((3, 512, 512))
+torch_tensor = t.to_torch()
+print(torch_tensor)
+print(type(torch_tensor))
+
+# convert from torch tensor to candle tensor
+t = torch.randn((3, 512, 512))
+candle_tensor = candle.Tensor(t)
+print(candle_tensor)
+print(type(candle_tensor))