author     Lukas Kreussel <65088241+LLukas22@users.noreply.github.com>  2023-10-19 22:46:21 +0200
committer  GitHub <noreply@github.com>  2023-10-19 21:46:21 +0100
commit     6684b7127a6476f3b1642243c752d51e85857f37 (patch)
tree       a810f3ac0a38fd5151a945105927fa15a9be1dd3 /candle-pyo3
parent     93c25e8844e8db2c697f1e6e9d4a06dac1ca3569 (diff)
PyO3: Add pytorch-like `.to()` operator to `candle.Tensor` (#1100)
* add `.to()` operator
* Only allow each value to be provided once via `args` or `kwargs`
Diffstat (limited to 'candle-pyo3')
-rw-r--r--  candle-pyo3/py_src/candle/__init__.pyi  |   5
-rw-r--r--  candle-pyo3/src/lib.rs                   | 106
-rw-r--r--  candle-pyo3/tests/native/test_tensor.py  |  65
3 files changed, 176 insertions(+), 0 deletions(-)
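
The new operator mirrors PyTorch's `Tensor.to`: a dtype, a device, or a template tensor can be passed positionally or by keyword, but each value only once. A minimal usage sketch (device moves assume a CUDA-enabled build; the calls mirror the tests added in this patch):

    import candle
    from candle import Tensor

    t = Tensor(42.0)                    # f32 on cpu by default

    t64 = t.to(candle.f64)              # cast via a positional dtype
    t64 = t.to(dtype=candle.f64)        # ...or via keyword

    t_gpu = t.to("cuda")                # move via a positional device
    t_gpu = t.to(device="cuda")         # ...or via keyword

    t_both = t.to("cuda", candle.f64)   # move and cast in one call

    other = Tensor(42.0).to("cuda", candle.f64)
    t_like = t.to(other)                # adopt `other`'s device and dtype
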
diff --git a/candle-pyo3/py_src/candle/__init__.pyi b/candle-pyo3/py_src/candle/__init__.pyi
index 4096907b..7a0b2fcf 100644
--- a/candle-pyo3/py_src/candle/__init__.pyi
+++ b/candle-pyo3/py_src/candle/__init__.pyi
@@ -381,6 +381,11 @@ class Tensor:
         Transposes the tensor.
         """
         pass
+    def to(self, *args, **kwargs) -> Tensor:
+        """
+        Performs Tensor dtype and/or device conversion.
+        """
+        pass
     def to_device(self, device: Union[str, Device]) -> Tensor:
         """
         Move the tensor to a new device.
diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs
index 55b20308..f9fdc712 100644
--- a/candle-pyo3/src/lib.rs
+++ b/candle-pyo3/src/lib.rs
@@ -772,6 +772,112 @@ impl PyTensor {
         Ok(PyTensor(self.0.copy().map_err(wrap_err)?))
     }
 
+    #[pyo3(signature = (*args, **kwargs), text_signature = "(self, *args, **kwargs)")]
+    /// Performs Tensor dtype and/or device conversion.
+    /// &RETURNS&: Tensor
+    fn to(&self, args: &PyTuple, kwargs: Option<&PyDict>) -> PyResult<Self> {
+        let mut device: Option<PyDevice> = None;
+        let mut dtype: Option<PyDType> = None;
+        let mut other: Option<PyTensor> = None;
+
+        fn handle_duplicates<T>(
+            opt: &mut Option<T>,
+            extraction_result: PyResult<T>,
+            err_msg: &'static str,
+        ) -> PyResult<()> {
+            if let Ok(successful_extraction) = extraction_result {
+                if opt.is_some() {
+                    return Err(PyValueError::new_err(err_msg));
+                }
+                *opt = Some(successful_extraction);
+            }
+            Ok(())
+        }
+
+        // handle args
+        for arg in args.iter() {
+            if arg.extract::<PyDevice>().is_ok() {
+                handle_duplicates(
+                    &mut device,
+                    arg.extract::<PyDevice>(),
+                    "cannot specify multiple devices",
+                )?;
+            } else if arg.extract::<PyDType>().is_ok() {
+                handle_duplicates(
+                    &mut dtype,
+                    arg.extract::<PyDType>(),
+                    "cannot specify multiple dtypes",
+                )?;
+            } else if arg.extract::<PyTensor>().is_ok() {
+                handle_duplicates(
+                    &mut other,
+                    arg.extract::<PyTensor>(),
+                    "cannot specify multiple output tensors",
+                )?;
+            } else {
+                return Err(PyTypeError::new_err(format!(
+                    "unsupported argument type `{:#?}`",
+                    arg.get_type().name()
+                )));
+            }
+        }
+
+        if let Some(kwargs) = kwargs {
+            if let Some(any) = kwargs.get_item("dtype") {
+                handle_duplicates(
+                    &mut dtype,
+                    any.extract::<PyDType>(),
+                    "cannot specify multiple dtypes",
+                )?;
+            }
+            if let Some(any) = kwargs.get_item("device") {
+                handle_duplicates(
+                    &mut device,
+                    any.extract::<PyDevice>(),
+                    "cannot specify multiple devices",
+                )?;
+            }
+            if let Some(any) = kwargs.get_item("other") {
+                handle_duplicates(
+                    &mut other,
+                    any.extract::<PyTensor>(),
+                    "cannot specify multiple output tensors",
+                )?;
+            }
+        }
+
+        if let Some(other) = other {
+            if device.is_some() {
+                return Err(PyValueError::new_err(
+                    "cannot specify both an output tensor and a device",
+                ));
+            }
+            if dtype.is_some() {
+                return Err(PyValueError::new_err(
+                    "cannot specify both an output tensor and a dtype",
+                ));
+            }
+            dtype = Some(other.dtype());
+            device = Some(PyDevice::from_device(other.0.device()));
+        }
+
+        let result = match (device, dtype) {
+            (Some(device), Some(dtype)) => self
+                .0
+                .to_device(&device.as_device()?)
+                .map_err(wrap_err)?
+                .to_dtype(dtype.0)
+                .map_err(wrap_err)?,
+            (Some(device), None) => self.0.to_device(&device.as_device()?).map_err(wrap_err)?,
+            (None, Some(dtype)) => self.0.to_dtype(dtype.0).map_err(wrap_err)?,
+            (None, None) => {
+                return Err(PyTypeError::new_err("No valid dtype or device specified"))
+            }
+        };
+
+        Ok(PyTensor(result))
+    }
+
     #[pyo3(text_signature = "(self, dtype:Union[str,DType])")]
     /// Convert the tensor to a new dtype.
     /// &RETURNS&: Tensor
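
To summarize the resolution rules implemented above: positional arguments are matched by attempted type extraction, keyword and positional values fill the same slots (so duplicates raise ValueError), unrecognized values raise TypeError, and a template tensor is mutually exclusive with an explicit device or dtype but contributes both of its own. A rough Python sketch of that dispatch, for illustration only — `resolve_to_args` is a hypothetical helper, and the isinstance checks plus the "cpu"/"cuda" string test are simplified stand-ins for the PyO3 `extract::<T>()` attempts:

    import candle
    from candle import Tensor

    def resolve_to_args(args, kwargs):
        # Illustrative sketch only; the real logic lives in Rust (PyTensor::to).
        device = dtype = other = None

        def put(name, current, value):
            # Each slot may be filled at most once, whether via args or kwargs.
            if current is not None:
                raise ValueError(f"cannot specify multiple {name}s")
            return value

        for arg in args:
            if isinstance(arg, str) and arg in ("cpu", "cuda"):  # simplified device check
                device = put("device", device, arg)
            elif isinstance(arg, candle.DType):
                dtype = put("dtype", dtype, arg)
            elif isinstance(arg, Tensor):
                other = put("output tensor", other, arg)
            else:
                raise TypeError(f"unsupported argument type `{type(arg).__name__}`")

        if "device" in kwargs:
            device = put("device", device, kwargs["device"])
        if "dtype" in kwargs:
            dtype = put("dtype", dtype, kwargs["dtype"])
        if "other" in kwargs:
            other = put("output tensor", other, kwargs["other"])

        # A template tensor excludes an explicit device/dtype and supplies both.
        if other is not None:
            if device is not None or dtype is not None:
                raise ValueError("cannot specify both an output tensor and a device/dtype")
            device, dtype = other.device, other.dtype

        if device is None and dtype is None:
            raise TypeError("no valid dtype or device specified")
        return device, dtype
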
diff --git a/candle-pyo3/tests/native/test_tensor.py b/candle-pyo3/tests/native/test_tensor.py
index 225a7469..659423e0 100644
--- a/candle-pyo3/tests/native/test_tensor.py
+++ b/candle-pyo3/tests/native/test_tensor.py
@@ -1,5 +1,6 @@
 import candle
 from candle import Tensor
+from candle.utils import cuda_is_available
 import pytest
 
 
@@ -75,6 +76,70 @@ def test_tensor_can_be_scliced_3d():
     assert t[..., 0:2].values() == [[[1, 2], [5, 6]], [[9, 10], [13, 14]]]
 
 
+def test_tensor_can_be_cast_via_to():
+    t = Tensor(42.0)
+    assert str(t.dtype) == str(candle.f32)
+    t_new_args = t.to(candle.f64)
+    assert str(t_new_args.dtype) == str(candle.f64)
+    t_new_kwargs = t.to(dtype=candle.f64)
+    assert str(t_new_kwargs.dtype) == str(candle.f64)
+    pytest.raises(TypeError, lambda: t.to("not a dtype"))
+    pytest.raises(TypeError, lambda: t.to(dtype="not a dtype"))
+    pytest.raises(TypeError, lambda: t.to(candle.f64, "not a dtype"))
+    pytest.raises(TypeError, lambda: t.to())
+    pytest.raises(ValueError, lambda: t.to(candle.f16, dtype=candle.f64))
+    pytest.raises(ValueError, lambda: t.to(candle.f16, candle.f16))
+
+    other = Tensor(42.0).to(candle.f64)
+    t_new_other_args = t.to(other)
+    assert str(t_new_other_args.dtype) == str(candle.f64)
+    t_new_other_kwargs = t.to(other=other)
+    assert str(t_new_other_kwargs.dtype) == str(candle.f64)
+
+
+@pytest.mark.skipif(not cuda_is_available(), reason="CUDA is not available")
+def test_tensor_can_be_moved_via_to():
+    t = Tensor(42.0)
+    assert t.device == "cpu"
+    t_new_args = t.to("cuda")
+    assert t_new_args.device == "cuda"
+    t_new_kwargs = t.to(device="cuda")
+    assert t_new_kwargs.device == "cuda"
+    pytest.raises(TypeError, lambda: t.to("not a device"))
+    pytest.raises(TypeError, lambda: t.to(device="not a device"))
+    pytest.raises(TypeError, lambda: t.to("cuda", "not a device"))
+    pytest.raises(TypeError, lambda: t.to())
+    pytest.raises(ValueError, lambda: t.to("cuda", device="cpu"))
+    pytest.raises(ValueError, lambda: t.to("cuda", "cuda"))
+
+    other = Tensor(42.0).to("cuda")
+    t_new_other_args = t.to(other)
+    assert t_new_other_args.device == "cuda"
+    t_new_other_kwargs = t.to(other=other)
+    assert t_new_other_kwargs.device == "cuda"
+
+
+@pytest.mark.skipif(not cuda_is_available(), reason="CUDA is not available")
+def test_tensor_can_be_moved_and_cast_via_to():
+    t = Tensor(42.0)
+    assert t.device == "cpu"
+    assert str(t.dtype) == str(candle.f32)
+    t_new_args = t.to("cuda", candle.f64)
+    assert t_new_args.device == "cuda"
+    assert str(t_new_args.dtype) == str(candle.f64)
+    t_new_kwargs = t.to(device="cuda", dtype=candle.f64)
+    assert t_new_kwargs.device == "cuda"
+    assert str(t_new_kwargs.dtype) == str(candle.f64)
+
+    other = Tensor(42.0).to("cuda").to(candle.f64)
+    t_new_other_args = t.to(other)
+    assert t_new_other_args.device == "cuda"
+    assert str(t_new_other_args.dtype) == str(candle.f64)
+    t_new_other_kwargs = t.to(other=other)
+    assert t_new_other_kwargs.device == "cuda"
+    assert str(t_new_other_kwargs.dtype) == str(candle.f64)
+
+
 def test_tensor_can_be_added():
     t = Tensor(42.0)
     result = t + t