diff options
author | andrew <trasuadev@gmail.com> | 2023-10-26 01:39:14 +0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-10-25 19:39:14 +0100 |
commit | 6a446d9d73da64daea4dc75e9b57dba78d4180fb (patch) | |
tree | d5a84cc1ab98f17e9d2739da66641deb5a1f5ef7 /candle-pyo3/src | |
parent | 0acd16751d6e0a501bba6c6285a18ccc40fad59b (diff) | |
download | candle-6a446d9d73da64daea4dc75e9b57dba78d4180fb.tar.gz candle-6a446d9d73da64daea4dc75e9b57dba78d4180fb.tar.bz2 candle-6a446d9d73da64daea4dc75e9b57dba78d4180fb.zip |
convert pytorch's tensor in Python API (#1172)
* convert pytorch's tensor
* separate tests for convert pytorch tensor
Diffstat (limited to 'candle-pyo3/src')
-rw-r--r-- | candle-pyo3/src/lib.rs | 24 |
1 files changed, 24 insertions, 0 deletions
diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs index e2c8014f..6d4de80b 100644 --- a/candle-pyo3/src/lib.rs +++ b/candle-pyo3/src/lib.rs @@ -211,6 +211,16 @@ enum Indexer { IndexSelect(Tensor), } +#[derive(Clone, Debug)] +struct TorchTensor(PyObject); + +impl<'source> pyo3::FromPyObject<'source> for TorchTensor { + fn extract(ob: &'source PyAny) -> PyResult<Self> { + let numpy_value: PyObject = ob.getattr("numpy")?.call0()?.extract()?; + Ok(TorchTensor(numpy_value)) + } +} + #[pymethods] impl PyTensor { #[new] @@ -246,6 +256,8 @@ impl PyTensor { Tensor::new(vs, &Cpu).map_err(wrap_err)? } else if let Ok(vs) = data.extract::<Vec<Vec<Vec<f32>>>>(py) { Tensor::new(vs, &Cpu).map_err(wrap_err)? + } else if let Ok(TorchTensor(numpy)) = data.extract::<TorchTensor>(py) { + return PyTensor::new(py, numpy); } else { let ty = data.as_ref(py).get_type(); Err(PyTypeError::new_err(format!( @@ -299,6 +311,18 @@ impl PyTensor { M(py).map(self) } + /// Converts candle's tensor to pytorch's tensor + /// &RETURNS&: torch.Tensor + fn to_torch(&self, py: Python<'_>) -> PyResult<PyObject> { + let candle_values = self.values(py)?; + let torch_tensor: PyObject = py + .import("torch")? + .getattr("tensor")? + .call1((candle_values,))? + .extract()?; + Ok(torch_tensor) + } + #[getter] /// Gets the tensor's shape. /// &RETURNS&: Tuple[int] |