summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-pyo3/py_src/candle/__init__.pyi5
-rw-r--r--candle-pyo3/src/lib.rs24
-rw-r--r--candle-pyo3/test_pytorch.py14
3 files changed, 43 insertions, 0 deletions
diff --git a/candle-pyo3/py_src/candle/__init__.pyi b/candle-pyo3/py_src/candle/__init__.pyi
index 7a0b2fcf..43722168 100644
--- a/candle-pyo3/py_src/candle/__init__.pyi
+++ b/candle-pyo3/py_src/candle/__init__.pyi
@@ -396,6 +396,11 @@ class Tensor:
Convert the tensor to a new dtype.
"""
pass
+ def to_torch(self) -> torch.Tensor:
+ """
+ Converts candle's tensor to pytorch's tensor
+ """
+ pass
def transpose(self, dim1: int, dim2: int) -> Tensor:
"""
Returns a tensor that is a transposed version of the input, the given dimensions are swapped.
diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs
index e2c8014f..6d4de80b 100644
--- a/candle-pyo3/src/lib.rs
+++ b/candle-pyo3/src/lib.rs
@@ -211,6 +211,16 @@ enum Indexer {
IndexSelect(Tensor),
}
+#[derive(Clone, Debug)]
+struct TorchTensor(PyObject);
+
+impl<'source> pyo3::FromPyObject<'source> for TorchTensor {
+ fn extract(ob: &'source PyAny) -> PyResult<Self> {
+ let numpy_value: PyObject = ob.getattr("numpy")?.call0()?.extract()?;
+ Ok(TorchTensor(numpy_value))
+ }
+}
+
#[pymethods]
impl PyTensor {
#[new]
@@ -246,6 +256,8 @@ impl PyTensor {
Tensor::new(vs, &Cpu).map_err(wrap_err)?
} else if let Ok(vs) = data.extract::<Vec<Vec<Vec<f32>>>>(py) {
Tensor::new(vs, &Cpu).map_err(wrap_err)?
+ } else if let Ok(TorchTensor(numpy)) = data.extract::<TorchTensor>(py) {
+ return PyTensor::new(py, numpy);
} else {
let ty = data.as_ref(py).get_type();
Err(PyTypeError::new_err(format!(
@@ -299,6 +311,18 @@ impl PyTensor {
M(py).map(self)
}
+ /// Converts candle's tensor to pytorch's tensor
+ /// &RETURNS&: torch.Tensor
+ fn to_torch(&self, py: Python<'_>) -> PyResult<PyObject> {
+ let candle_values = self.values(py)?;
+ let torch_tensor: PyObject = py
+ .import("torch")?
+ .getattr("tensor")?
+ .call1((candle_values,))?
+ .extract()?;
+ Ok(torch_tensor)
+ }
+
#[getter]
/// Gets the tensor's shape.
/// &RETURNS&: Tuple[int]
diff --git a/candle-pyo3/test_pytorch.py b/candle-pyo3/test_pytorch.py
new file mode 100644
index 00000000..db0f3522
--- /dev/null
+++ b/candle-pyo3/test_pytorch.py
@@ -0,0 +1,14 @@
+import candle
+import torch
+
+# convert from candle tensor to torch tensor
+t = candle.randn((3, 512, 512))
+torch_tensor = t.to_torch()
+print(torch_tensor)
+print(type(torch_tensor))
+
+# convert from torch tensor to candle tensor
+t = torch.randn((3, 512, 512))
+candle_tensor = candle.Tensor(t)
+print(candle_tensor)
+print(type(candle_tensor))