author | Lukas Kreussel <65088241+LLukas22@users.noreply.github.com> | 2023-10-19 22:46:21 +0200
---|---|---
committer | GitHub <noreply@github.com> | 2023-10-19 21:46:21 +0100
commit | 6684b7127a6476f3b1642243c752d51e85857f37 (patch) |
tree | a810f3ac0a38fd5151a945105927fa15a9be1dd3 | /candle-pyo3
parent | 93c25e8844e8db2c697f1e6e9d4a06dac1ca3569 (diff) |
download | candle-6684b7127a6476f3b1642243c752d51e85857f37.tar.gz candle-6684b7127a6476f3b1642243c752d51e85857f37.tar.bz2 candle-6684b7127a6476f3b1642243c752d51e85857f37.zip |
PyO3: Add pytorch like `.to()` operator to `candle.Tensor` (#1100)
* add `.to()` operator
* Only allow each value to be provided once via `args` or `kwargs`
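
A minimal usage sketch of the new operator, distilled from the tests added below (the `"cuda"` device string assumes a CUDA-enabled build):

    import candle
    from candle import Tensor

    t = Tensor(42.0)
    t.to(candle.f64)            # cast to a new dtype
    t.to(device="cuda")         # move to a new device
    t.to("cuda", candle.f64)    # move and cast in one call
    other = Tensor(0.0).to("cuda").to(candle.f64)
    t.to(other=other)           # adopt `other`'s device and dtype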
Diffstat (limited to 'candle-pyo3')
-rw-r--r-- | candle-pyo3/py_src/candle/__init__.pyi | 5
-rw-r--r-- | candle-pyo3/src/lib.rs | 106
-rw-r--r-- | candle-pyo3/tests/native/test_tensor.py | 65
3 files changed, 176 insertions, 0 deletions
diff --git a/candle-pyo3/py_src/candle/__init__.pyi b/candle-pyo3/py_src/candle/__init__.pyi
index 4096907b..7a0b2fcf 100644
--- a/candle-pyo3/py_src/candle/__init__.pyi
+++ b/candle-pyo3/py_src/candle/__init__.pyi
@@ -381,6 +381,11 @@ class Tensor:
         Transposes the tensor.
         """
         pass
+    def to(self, *args, **kwargs) -> Tensor:
+        """
+        Performs Tensor dtype and/or device conversion.
+        """
+        pass
     def to_device(self, device: Union[str, Device]) -> Tensor:
         """
         Move the tensor to a new device.
diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs
index 55b20308..f9fdc712 100644
--- a/candle-pyo3/src/lib.rs
+++ b/candle-pyo3/src/lib.rs
@@ -772,6 +772,112 @@ impl PyTensor {
         Ok(PyTensor(self.0.copy().map_err(wrap_err)?))
     }
 
+    #[pyo3(signature = (*args, **kwargs), text_signature = "(self, *args, **kwargs)")]
+    /// Performs Tensor dtype and/or device conversion.
+    /// &RETURNS&: Tensor
+    fn to(&self, args: &PyTuple, kwargs: Option<&PyDict>) -> PyResult<Self> {
+        let mut device: Option<PyDevice> = None;
+        let mut dtype: Option<PyDType> = None;
+        let mut other: Option<PyTensor> = None;
+
+        fn handle_duplicates<T>(
+            opt: &mut Option<T>,
+            extraction_result: PyResult<T>,
+            err_msg: &'static str,
+        ) -> PyResult<()> {
+            if let Ok(successful_extraction) = extraction_result {
+                if opt.is_some() {
+                    return Err(PyValueError::new_err(err_msg));
+                }
+                *opt = Some(successful_extraction);
+            }
+            Ok(())
+        }
+
+        // Handle positional arguments.
+        for arg in args.iter() {
+            if arg.extract::<PyDevice>().is_ok() {
+                handle_duplicates(
+                    &mut device,
+                    arg.extract::<PyDevice>(),
+                    "cannot specify multiple devices",
+                )?;
+            } else if arg.extract::<PyDType>().is_ok() {
+                handle_duplicates(
+                    &mut dtype,
+                    arg.extract::<PyDType>(),
+                    "cannot specify multiple dtypes",
+                )?;
+            } else if arg.extract::<PyTensor>().is_ok() {
+                handle_duplicates(
+                    &mut other,
+                    arg.extract::<PyTensor>(),
+                    "cannot specify multiple output tensors",
+                )?;
+            } else {
+                return Err(PyTypeError::new_err(format!(
+                    "unsupported argument type `{:#?}`",
+                    arg.get_type().name()
+                )));
+            }
+        }
+
+        if let Some(kwargs) = kwargs {
+            if let Some(any) = kwargs.get_item("dtype") {
+                handle_duplicates(
+                    &mut dtype,
+                    any.extract::<PyDType>(),
+                    "cannot specify multiple dtypes",
+                )?;
+            }
+            if let Some(any) = kwargs.get_item("device") {
+                handle_duplicates(
+                    &mut device,
+                    any.extract::<PyDevice>(),
+                    "cannot specify multiple devices",
+                )?;
+            }
+            if let Some(any) = kwargs.get_item("other") {
+                handle_duplicates(
+                    &mut other,
+                    any.extract::<PyTensor>(),
+                    "cannot specify multiple output tensors",
+                )?;
+            }
+        }
+
+        if let Some(other) = other {
+            if device.is_some() {
+                return Err(PyValueError::new_err(
+                    "cannot specify both an output tensor and a device",
+                ));
+            }
+            if dtype.is_some() {
+                return Err(PyValueError::new_err(
+                    "cannot specify both an output tensor and a dtype",
+                ));
+            }
+            dtype = Some(other.dtype());
+            device = Some(PyDevice::from_device(other.0.device()));
+        }
+
+        let result = match (device, dtype) {
+            (Some(device), Some(dtype)) => self
+                .0
+                .to_device(&device.as_device()?)
+                .map_err(wrap_err)?
+                .to_dtype(dtype.0)
+                .map_err(wrap_err)?,
+            (Some(device), None) => self.0.to_device(&device.as_device()?).map_err(wrap_err)?,
+            (None, Some(dtype)) => self.0.to_dtype(dtype.0).map_err(wrap_err)?,
+            (None, None) => {
+                return Err(PyTypeError::new_err("No valid dtype or device specified"))
+            }
+        };
+
+        Ok(PyTensor(result))
+    }
+
     #[pyo3(text_signature = "(self, dtype:Union[str,DType])")]
     /// Convert the tensor to a new dtype.
     /// &RETURNS&: Tensor
diff --git a/candle-pyo3/tests/native/test_tensor.py b/candle-pyo3/tests/native/test_tensor.py
index 225a7469..659423e0 100644
--- a/candle-pyo3/tests/native/test_tensor.py
+++ b/candle-pyo3/tests/native/test_tensor.py
@@ -1,5 +1,6 @@
 import candle
 from candle import Tensor
+from candle.utils import cuda_is_available
 
 import pytest
 
@@ -75,6 +76,70 @@ def test_tensor_can_be_scliced_3d():
     assert t[..., 0:2].values() == [[[1, 2], [5, 6]], [[9, 10], [13, 14]]]
 
 
+def test_tensor_can_be_cast_via_to():
+    t = Tensor(42.0)
+    assert str(t.dtype) == str(candle.f32)
+    t_new_args = t.to(candle.f64)
+    assert str(t_new_args.dtype) == str(candle.f64)
+    t_new_kwargs = t.to(dtype=candle.f64)
+    assert str(t_new_kwargs.dtype) == str(candle.f64)
+    pytest.raises(TypeError, lambda: t.to("not a dtype"))
+    pytest.raises(TypeError, lambda: t.to(dtype="not a dtype"))
+    pytest.raises(TypeError, lambda: t.to(candle.f64, "not a dtype"))
+    pytest.raises(TypeError, lambda: t.to())
+    pytest.raises(ValueError, lambda: t.to(candle.f16, dtype=candle.f64))
+    pytest.raises(ValueError, lambda: t.to(candle.f16, candle.f16))
+
+    other = Tensor(42.0).to(candle.f64)
+    t_new_other_args = t.to(other)
+    assert str(t_new_other_args.dtype) == str(candle.f64)
+    t_new_other_kwargs = t.to(other=other)
+    assert str(t_new_other_kwargs.dtype) == str(candle.f64)
+
+
+@pytest.mark.skipif(not cuda_is_available(), reason="CUDA is not available")
+def test_tensor_can_be_moved_via_to():
+    t = Tensor(42.0)
+    assert t.device == "cpu"
+    t_new_args = t.to("cuda")
+    assert t_new_args.device == "cuda"
+    t_new_kwargs = t.to(device="cuda")
+    assert t_new_kwargs.device == "cuda"
+    pytest.raises(TypeError, lambda: t.to("not a device"))
+    pytest.raises(TypeError, lambda: t.to(device="not a device"))
+    pytest.raises(TypeError, lambda: t.to("cuda", "not a device"))
+    pytest.raises(TypeError, lambda: t.to())
+    pytest.raises(ValueError, lambda: t.to("cuda", device="cpu"))
+    pytest.raises(ValueError, lambda: t.to("cuda", "cuda"))
+
+    other = Tensor(42.0).to("cuda")
+    t_new_other_args = t.to(other)
+    assert t_new_other_args.device == "cuda"
+    t_new_other_kwargs = t.to(other=other)
+    assert t_new_other_kwargs.device == "cuda"
+
+
+@pytest.mark.skipif(not cuda_is_available(), reason="CUDA is not available")
+def test_tensor_can_be_moved_and_cast_via_to():
+    t = Tensor(42.0)
+    assert t.device == "cpu"
+    assert str(t.dtype) == str(candle.f32)
+    t_new_args = t.to("cuda", candle.f64)
+    assert t_new_args.device == "cuda"
+    assert str(t_new_args.dtype) == str(candle.f64)
+    t_new_kwargs = t.to(device="cuda", dtype=candle.f64)
+    assert t_new_kwargs.device == "cuda"
+    assert str(t_new_kwargs.dtype) == str(candle.f64)
+
+    other = Tensor(42.0).to("cuda").to(candle.f64)
+    t_new_other_args = t.to(other)
+    assert t_new_other_args.device == "cuda"
+    assert str(t_new_other_args.dtype) == str(candle.f64)
+    t_new_other_kwargs = t.to(other=other)
+    assert t_new_other_kwargs.device == "cuda"
+    assert str(t_new_other_kwargs.dtype) == str(candle.f64)
+
+
 def test_tensor_can_be_added():
     t = Tensor(42.0)
     result = t + t
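
Usage note: when both a device and a dtype are resolved, the implementation applies the device conversion first and the dtype conversion second, so a combined call behaves like chaining the pre-existing single-purpose methods. A sketch of that equivalence, assuming a CUDA-enabled build and the existing `to_device`/`to_dtype` methods shown in the surrounding context:

    import candle
    from candle import Tensor

    t = Tensor(42.0)
    combined = t.to("cuda", candle.f64)
    chained = t.to_device("cuda").to_dtype(candle.f64)  # device first, then dtype
    assert combined.device == chained.device
    assert str(combined.dtype) == str(chained.dtype)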