From 6a446d9d73da64daea4dc75e9b57dba78d4180fb Mon Sep 17 00:00:00 2001 From: andrew Date: Thu, 26 Oct 2023 01:39:14 +0700 Subject: convert pytorch's tensor in Python API (#1172) * convert pytorch's tensor * separate tests for convert pytorch tensor --- candle-pyo3/test_pytorch.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) create mode 100644 candle-pyo3/test_pytorch.py (limited to 'candle-pyo3/test_pytorch.py') diff --git a/candle-pyo3/test_pytorch.py b/candle-pyo3/test_pytorch.py new file mode 100644 index 00000000..db0f3522 --- /dev/null +++ b/candle-pyo3/test_pytorch.py @@ -0,0 +1,14 @@ +import candle +import torch + +# convert from candle tensor to torch tensor +t = candle.randn((3, 512, 512)) +torch_tensor = t.to_torch() +print(torch_tensor) +print(type(torch_tensor)) + +# convert from torch tensor to candle tensor +t = torch.randn((3, 512, 512)) +candle_tensor = candle.Tensor(t) +print(candle_tensor) +print(type(candle_tensor)) -- cgit v1.2.3