summaryrefslogtreecommitdiff
path: root/candle-pyo3/src
diff options
context:
space:
mode:
authorandrew <trasuadev@gmail.com>2023-10-26 01:39:14 +0700
committerGitHub <noreply@github.com>2023-10-25 19:39:14 +0100
commit6a446d9d73da64daea4dc75e9b57dba78d4180fb (patch)
treed5a84cc1ab98f17e9d2739da66641deb5a1f5ef7 /candle-pyo3/src
parent0acd16751d6e0a501bba6c6285a18ccc40fad59b (diff)
downloadcandle-6a446d9d73da64daea4dc75e9b57dba78d4180fb.tar.gz
candle-6a446d9d73da64daea4dc75e9b57dba78d4180fb.tar.bz2
candle-6a446d9d73da64daea4dc75e9b57dba78d4180fb.zip
convert pytorch's tensor in Python API (#1172)
* convert pytorch's tensor * separate tests for convert pytorch tensor
Diffstat (limited to 'candle-pyo3/src')
-rw-r--r--candle-pyo3/src/lib.rs24
1 files changed, 24 insertions, 0 deletions
diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs
index e2c8014f..6d4de80b 100644
--- a/candle-pyo3/src/lib.rs
+++ b/candle-pyo3/src/lib.rs
@@ -211,6 +211,16 @@ enum Indexer {
IndexSelect(Tensor),
}
+#[derive(Clone, Debug)]
+struct TorchTensor(PyObject);
+
+impl<'source> pyo3::FromPyObject<'source> for TorchTensor {
+ fn extract(ob: &'source PyAny) -> PyResult<Self> {
+ let numpy_value: PyObject = ob.getattr("numpy")?.call0()?.extract()?;
+ Ok(TorchTensor(numpy_value))
+ }
+}
+
#[pymethods]
impl PyTensor {
#[new]
@@ -246,6 +256,8 @@ impl PyTensor {
Tensor::new(vs, &Cpu).map_err(wrap_err)?
} else if let Ok(vs) = data.extract::<Vec<Vec<Vec<f32>>>>(py) {
Tensor::new(vs, &Cpu).map_err(wrap_err)?
+ } else if let Ok(TorchTensor(numpy)) = data.extract::<TorchTensor>(py) {
+ return PyTensor::new(py, numpy);
} else {
let ty = data.as_ref(py).get_type();
Err(PyTypeError::new_err(format!(
@@ -299,6 +311,18 @@ impl PyTensor {
M(py).map(self)
}
+ /// Converts candle's tensor to pytorch's tensor
+ /// &RETURNS&: torch.Tensor
+ fn to_torch(&self, py: Python<'_>) -> PyResult<PyObject> {
+ let candle_values = self.values(py)?;
+ let torch_tensor: PyObject = py
+ .import("torch")?
+ .getattr("tensor")?
+ .call1((candle_values,))?
+ .extract()?;
+ Ok(torch_tensor)
+ }
+
#[getter]
/// Gets the tensor's shape.
/// &RETURNS&: Tuple[int]