summaryrefslogtreecommitdiff
path: root/candle-pyo3/test_pytorch.py
blob: db0f35227c0370a46d33eb88d52f4107c2eba95e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
import candle
import torch

# Shape of the random tensors exchanged between candle and torch.
SHAPE = (3, 512, 512)


def test_candle_to_torch():
    """Convert a random candle tensor to torch and check type and shape survive.

    The converted tensor and its type are printed so the file stays useful as a
    manual smoke check when run directly.
    """
    candle_src = candle.randn(SHAPE)
    torch_tensor = candle_src.to_torch()
    print(torch_tensor)
    print(type(torch_tensor))
    assert isinstance(torch_tensor, torch.Tensor)
    # torch.Size is a tuple subclass; tuple() makes the comparison explicit.
    assert tuple(torch_tensor.shape) == SHAPE


def test_torch_to_candle():
    """Wrap a random torch tensor in ``candle.Tensor`` and check type and shape."""
    torch_src = torch.randn(SHAPE)
    candle_tensor = candle.Tensor(torch_src)
    print(candle_tensor)
    print(type(candle_tensor))
    assert isinstance(candle_tensor, candle.Tensor)
    # NOTE(review): assumes candle.Tensor exposes `shape` as a sequence of
    # dims — confirm against the candle-pyo3 bindings.
    assert tuple(candle_tensor.shape) == SHAPE


# Running the file directly keeps the original print-based smoke check; under
# pytest the two functions above are collected and run as real tests instead
# of executing the conversions as a side effect of importing the module.
if __name__ == "__main__":
    test_candle_to_torch()
    test_torch_to_candle()