summaryrefslogtreecommitdiff
path: root/candle-pyo3/test_pytorch.py
diff options
context:
space:
mode:
authorandrew <trasuadev@gmail.com>2023-10-26 01:39:14 +0700
committerGitHub <noreply@github.com>2023-10-25 19:39:14 +0100
commit6a446d9d73da64daea4dc75e9b57dba78d4180fb (patch)
treed5a84cc1ab98f17e9d2739da66641deb5a1f5ef7 /candle-pyo3/test_pytorch.py
parent0acd16751d6e0a501bba6c6285a18ccc40fad59b (diff)
downloadcandle-6a446d9d73da64daea4dc75e9b57dba78d4180fb.tar.gz
candle-6a446d9d73da64daea4dc75e9b57dba78d4180fb.tar.bz2
candle-6a446d9d73da64daea4dc75e9b57dba78d4180fb.zip
convert pytorch's tensor in Python API (#1172)
* convert pytorch's tensor * separate tests for convert pytorch tensor
Diffstat (limited to 'candle-pyo3/test_pytorch.py')
-rw-r--r--candle-pyo3/test_pytorch.py14
1 files changed, 14 insertions, 0 deletions
diff --git a/candle-pyo3/test_pytorch.py b/candle-pyo3/test_pytorch.py
new file mode 100644
index 00000000..db0f3522
--- /dev/null
+++ b/candle-pyo3/test_pytorch.py
@@ -0,0 +1,14 @@
+import candle
+import torch
+
+# convert from candle tensor to torch tensor
+t = candle.randn((3, 512, 512))
+torch_tensor = t.to_torch()
+print(torch_tensor)
+print(type(torch_tensor))
+
+# convert from torch tensor to candle tensor
+t = torch.randn((3, 512, 512))
+candle_tensor = candle.Tensor(t)
+print(candle_tensor)
+print(type(candle_tensor))