summaryrefslogtreecommitdiff
path: root/candle-pyo3/py_src
diff options
context:
space:
mode:
authorandrew <trasuadev@gmail.com>2023-10-26 01:39:14 +0700
committerGitHub <noreply@github.com>2023-10-25 19:39:14 +0100
commit6a446d9d73da64daea4dc75e9b57dba78d4180fb (patch)
treed5a84cc1ab98f17e9d2739da66641deb5a1f5ef7 /candle-pyo3/py_src
parent0acd16751d6e0a501bba6c6285a18ccc40fad59b (diff)
downloadcandle-6a446d9d73da64daea4dc75e9b57dba78d4180fb.tar.gz
candle-6a446d9d73da64daea4dc75e9b57dba78d4180fb.tar.bz2
candle-6a446d9d73da64daea4dc75e9b57dba78d4180fb.zip
convert pytorch's tensor in Python API (#1172)
* convert pytorch's tensor * separate tests for convert pytorch tensor
Diffstat (limited to 'candle-pyo3/py_src')
-rw-r--r--candle-pyo3/py_src/candle/__init__.pyi5
1 files changed, 5 insertions, 0 deletions
diff --git a/candle-pyo3/py_src/candle/__init__.pyi b/candle-pyo3/py_src/candle/__init__.pyi
index 7a0b2fcf..43722168 100644
--- a/candle-pyo3/py_src/candle/__init__.pyi
+++ b/candle-pyo3/py_src/candle/__init__.pyi
@@ -396,6 +396,11 @@ class Tensor:
Convert the tensor to a new dtype.
"""
pass
+ def to_torch(self) -> torch.Tensor:
+ """
+ Converts candle's tensor to pytorch's tensor
+ """
+ pass
def transpose(self, dim1: int, dim2: int) -> Tensor:
"""
Returns a tensor that is a transposed version of the input, the given dimensions are swapped.