diff options
author | andrew <trasuadev@gmail.com> | 2023-10-26 01:39:14 +0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-10-25 19:39:14 +0100 |
commit | 6a446d9d73da64daea4dc75e9b57dba78d4180fb (patch) | |
tree | d5a84cc1ab98f17e9d2739da66641deb5a1f5ef7 /candle-pyo3/test_pytorch.py | |
parent | 0acd16751d6e0a501bba6c6285a18ccc40fad59b (diff) | |
download | candle-6a446d9d73da64daea4dc75e9b57dba78d4180fb.tar.gz candle-6a446d9d73da64daea4dc75e9b57dba78d4180fb.tar.bz2 candle-6a446d9d73da64daea4dc75e9b57dba78d4180fb.zip |
convert pytorch's tensor in Python API (#1172)
* convert pytorch's tensor
* separate tests for convert pytorch tensor
Diffstat (limited to 'candle-pyo3/test_pytorch.py')
-rw-r--r-- | candle-pyo3/test_pytorch.py | 14 |
1 files changed, 14 insertions, 0 deletions
diff --git a/candle-pyo3/test_pytorch.py b/candle-pyo3/test_pytorch.py new file mode 100644 index 00000000..db0f3522 --- /dev/null +++ b/candle-pyo3/test_pytorch.py @@ -0,0 +1,14 @@ +import candle +import torch + +# convert from candle tensor to torch tensor +t = candle.randn((3, 512, 512)) +torch_tensor = t.to_torch() +print(torch_tensor) +print(type(torch_tensor)) + +# convert from torch tensor to candle tensor +t = torch.randn((3, 512, 512)) +candle_tensor = candle.Tensor(t) +print(candle_tensor) +print(type(candle_tensor)) |