diff options
author | Lukas Kreussel <65088241+LLukas22@users.noreply.github.com> | 2023-10-29 16:41:44 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-10-29 15:41:44 +0000 |
commit | 174b20805230abaf91b838598d84ab142f31a975 (patch) | |
tree | 775e1ca849f2e9f0ac2a32df476106ec1bb3a5dc /candle-pyo3 | |
parent | 154c674a798fd5a40d57ff9a8664856d9c41ca56 (diff) | |
download | candle-174b20805230abaf91b838598d84ab142f31a975.tar.gz candle-174b20805230abaf91b838598d84ab142f31a975.tar.bz2 candle-174b20805230abaf91b838598d84ab142f31a975.zip |
PyO3: Better shape handling (#1143)
* Negative and `*args` shape handling
* Rename to `PyShapeWithHole` + validate that only one hole exists
* Regenerate stubs
---------
Co-authored-by: Laurent Mazare <laurent.mazare@gmail.com>
Diffstat (limited to 'candle-pyo3')
-rw-r--r-- | candle-pyo3/Cargo.toml | 4 | ||||
-rw-r--r-- | candle-pyo3/py_src/candle/__init__.pyi | 16 | ||||
-rw-r--r-- | candle-pyo3/py_src/candle/functional/__init__.pyi | 2 | ||||
-rw-r--r-- | candle-pyo3/py_src/candle/typing/__init__.py | 2 | ||||
-rw-r--r-- | candle-pyo3/py_src/candle/utils/__init__.pyi | 2 | ||||
-rw-r--r-- | candle-pyo3/src/lib.rs | 71 | ||||
-rw-r--r-- | candle-pyo3/src/shape.rs | 99 | ||||
-rw-r--r-- | candle-pyo3/stub.py | 2 | ||||
-rw-r--r-- | candle-pyo3/tests/native/test_shape.py | 31 |
9 files changed, 180 insertions, 49 deletions
diff --git a/candle-pyo3/Cargo.toml b/candle-pyo3/Cargo.toml index 488404bf..b0452404 100644 --- a/candle-pyo3/Cargo.toml +++ b/candle-pyo3/Cargo.toml @@ -19,10 +19,10 @@ candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" } candle-nn = { path = "../candle-nn", version = "0.3.0" } half = { workspace = true } intel-mkl-src = { workspace = true, optional = true } -pyo3 = { version = "0.19.0", features = ["extension-module", "abi3-py38"] } +pyo3 = { version = "0.20.0", features = ["extension-module", "abi3-py38"] } [build-dependencies] -pyo3-build-config = "0.19" +pyo3-build-config = "0.20" [features] default = [] diff --git a/candle-pyo3/py_src/candle/__init__.pyi b/candle-pyo3/py_src/candle/__init__.pyi index 43722168..35b17680 100644 --- a/candle-pyo3/py_src/candle/__init__.pyi +++ b/candle-pyo3/py_src/candle/__init__.pyi @@ -1,7 +1,7 @@ # Generated content DO NOT EDIT from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence from os import PathLike -from candle.typing import _ArrayLike, Device, Scalar, Index +from candle.typing import _ArrayLike, Device, Scalar, Index, Shape class bf16(DType): pass @@ -26,21 +26,21 @@ class i64(DType): pass @staticmethod -def ones(shape: Sequence[int], dtype: Optional[DType] = None, device: Optional[Device] = None) -> Tensor: +def ones(*shape: Shape, dtype: Optional[DType] = None, device: Optional[Device] = None) -> Tensor: """ Creates a new tensor filled with ones. """ pass @staticmethod -def rand(shape: Sequence[int], device: Optional[Device] = None) -> Tensor: +def rand(*shape: Shape, device: Optional[Device] = None) -> Tensor: """ Creates a new tensor with random values. """ pass @staticmethod -def randn(shape: Sequence[int], device: Optional[Device] = None) -> Tensor: +def randn(*shape: Shape, device: Optional[Device] = None) -> Tensor: """ Creates a new tensor with random values from a normal distribution. """ @@ -67,7 +67,7 @@ class u8(DType): pass @staticmethod -def zeros(shape: Sequence[int], dtype: Optional[DType] = None, device: Optional[Device] = None) -> Tensor: +def zeros(*shape: Shape, dtype: Optional[DType] = None, device: Optional[Device] = None) -> Tensor: """ Creates a new tensor filled with zeros. """ @@ -174,7 +174,7 @@ class Tensor: Adds the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor. """ pass - def broadcast_as(self, shape: Sequence[int]) -> Tensor: + def broadcast_as(self, *shape: Shape) -> Tensor: """ Broadcasts the tensor to the given shape. """ @@ -184,7 +184,7 @@ class Tensor: Divides the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor. """ pass - def broadcast_left(self, shape: Sequence[int]) -> Tensor: + def broadcast_left(self, *shape: Shape) -> Tensor: """ Broadcasts the tensor to the given shape, adding new dimensions on the left. """ @@ -329,7 +329,7 @@ class Tensor: Get the `recip` of the tensor. """ pass - def reshape(self, shape: Sequence[int]) -> Tensor: + def reshape(self, *shape: Shape) -> Tensor: """ Reshapes the tensor to the given shape. """ diff --git a/candle-pyo3/py_src/candle/functional/__init__.pyi b/candle-pyo3/py_src/candle/functional/__init__.pyi index 5bf5c4c3..4f7c2aa6 100644 --- a/candle-pyo3/py_src/candle/functional/__init__.pyi +++ b/candle-pyo3/py_src/candle/functional/__init__.pyi @@ -1,7 +1,7 @@ # Generated content DO NOT EDIT from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence from os import PathLike -from candle.typing import _ArrayLike, Device, Scalar, Index +from candle.typing import _ArrayLike, Device, Scalar, Index, Shape from candle import Tensor, DType, QTensor @staticmethod diff --git a/candle-pyo3/py_src/candle/typing/__init__.py b/candle-pyo3/py_src/candle/typing/__init__.py index 66bc3d8a..b2262a97 100644 --- a/candle-pyo3/py_src/candle/typing/__init__.py +++ b/candle-pyo3/py_src/candle/typing/__init__.py @@ -18,3 +18,5 @@ Device = TypeVar("Device", CPU, CUDA) Scalar = Union[int, float] Index = Union[int, slice, None, "Ellipsis"] + +Shape = Union[int, Sequence[int]] diff --git a/candle-pyo3/py_src/candle/utils/__init__.pyi b/candle-pyo3/py_src/candle/utils/__init__.pyi index d3b93766..4ee51c29 100644 --- a/candle-pyo3/py_src/candle/utils/__init__.pyi +++ b/candle-pyo3/py_src/candle/utils/__init__.pyi @@ -1,7 +1,7 @@ # Generated content DO NOT EDIT from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence from os import PathLike -from candle.typing import _ArrayLike, Device, Scalar, Index +from candle.typing import _ArrayLike, Device, Scalar, Index, Shape from candle import Tensor, DType, QTensor @staticmethod diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs index 6d4de80b..41c4577f 100644 --- a/candle-pyo3/src/lib.rs +++ b/candle-pyo3/src/lib.rs @@ -16,27 +16,14 @@ extern crate accelerate_src; use ::candle::{quantized::QTensor, DType, Device, Tensor, WithDType}; +mod shape; +use shape::{PyShape, PyShapeWithHole}; + pub fn wrap_err(err: ::candle::Error) -> PyErr { PyErr::new::<PyValueError, _>(format!("{err:?}")) } #[derive(Clone, Debug)] -struct PyShape(Vec<usize>); - -impl<'source> pyo3::FromPyObject<'source> for PyShape { - fn extract(ob: &'source PyAny) -> PyResult<Self> { - let dims: Vec<usize> = pyo3::FromPyObject::extract(ob)?; - Ok(PyShape(dims)) - } -} - -impl From<PyShape> for ::candle::Shape { - fn from(val: PyShape) -> Self { - val.0.into() - } -} - -#[derive(Clone, Debug)] #[pyclass(name = "Tensor")] /// A `candle` tensor. struct PyTensor(Tensor); @@ -684,25 +671,37 @@ impl PyTensor { Ok(Self(tensor)) } - #[pyo3(text_signature = "(self, shape:Sequence[int])")] + #[pyo3(signature=(*shape), text_signature = "(self, *shape:Shape)")] /// Reshapes the tensor to the given shape. /// &RETURNS&: Tensor - fn reshape(&self, shape: PyShape) -> PyResult<Self> { - Ok(PyTensor(self.0.reshape(shape).map_err(wrap_err)?)) + fn reshape(&self, shape: PyShapeWithHole) -> PyResult<Self> { + Ok(PyTensor( + self.0 + .reshape(shape.to_absolute(&self.0)?) + .map_err(wrap_err)?, + )) } - #[pyo3(text_signature = "(self, shape:Sequence[int])")] + #[pyo3(signature=(*shape), text_signature = "(self, *shape:Shape)")] /// Broadcasts the tensor to the given shape. /// &RETURNS&: Tensor - fn broadcast_as(&self, shape: PyShape) -> PyResult<Self> { - Ok(PyTensor(self.0.broadcast_as(shape).map_err(wrap_err)?)) + fn broadcast_as(&self, shape: PyShapeWithHole) -> PyResult<Self> { + Ok(PyTensor( + self.0 + .broadcast_as(shape.to_absolute(&self.0)?) + .map_err(wrap_err)?, + )) } - #[pyo3(text_signature = "(self, shape:Sequence[int])")] + #[pyo3(signature=(*shape), text_signature = "(self, *shape:Shape)")] /// Broadcasts the tensor to the given shape, adding new dimensions on the left. /// &RETURNS&: Tensor - fn broadcast_left(&self, shape: PyShape) -> PyResult<Self> { - Ok(PyTensor(self.0.broadcast_left(shape).map_err(wrap_err)?)) + fn broadcast_left(&self, shape: PyShapeWithHole) -> PyResult<Self> { + Ok(PyTensor( + self.0 + .broadcast_left(shape.to_absolute(&self.0)?) + .map_err(wrap_err)?, + )) } #[pyo3(text_signature = "(self, dim:int)")] @@ -915,21 +914,21 @@ impl PyTensor { } if let Some(kwargs) = kwargs { - if let Some(any) = kwargs.get_item("dtype") { + if let Ok(Some(any)) = kwargs.get_item("dtype") { handle_duplicates( &mut dtype, any.extract::<PyDType>(), "cannot specify multiple dtypes", )?; } - if let Some(any) = kwargs.get_item("device") { + if let Ok(Some(any)) = kwargs.get_item("device") { handle_duplicates( &mut device, any.extract::<PyDevice>(), "cannot specify multiple devices", )?; } - if let Some(any) = kwargs.get_item("other") { + if let Ok(Some(any)) = kwargs.get_item("other") { handle_duplicates( &mut other, any.extract::<PyTensor>(), @@ -1049,27 +1048,27 @@ fn tensor(py: Python<'_>, data: PyObject) -> PyResult<PyTensor> { } #[pyfunction] -#[pyo3(signature = (shape, *, device=None), text_signature = "(shape:Sequence[int], device:Optional[Device]=None)")] +#[pyo3(signature = (*shape,device=None), text_signature = "(*shape:Shape, device:Optional[Device]=None)")] /// Creates a new tensor with random values. /// &RETURNS&: Tensor fn rand(_py: Python<'_>, shape: PyShape, device: Option<PyDevice>) -> PyResult<PyTensor> { let device = device.unwrap_or(PyDevice::Cpu).as_device()?; - let tensor = Tensor::rand(0f32, 1f32, shape.0, &device).map_err(wrap_err)?; + let tensor = Tensor::rand(0f32, 1f32, shape, &device).map_err(wrap_err)?; Ok(PyTensor(tensor)) } #[pyfunction] -#[pyo3(signature = (shape, *, device=None), text_signature = "(shape:Sequence[int], device:Optional[Device]=None)")] +#[pyo3(signature = (*shape,device=None), text_signature = "(*shape:Shape, device:Optional[Device]=None)")] /// Creates a new tensor with random values from a normal distribution. /// &RETURNS&: Tensor fn randn(_py: Python<'_>, shape: PyShape, device: Option<PyDevice>) -> PyResult<PyTensor> { let device = device.unwrap_or(PyDevice::Cpu).as_device()?; - let tensor = Tensor::randn(0f32, 1f32, shape.0, &device).map_err(wrap_err)?; + let tensor = Tensor::randn(0f32, 1f32, shape, &device).map_err(wrap_err)?; Ok(PyTensor(tensor)) } #[pyfunction] -#[pyo3(signature = (shape, *, dtype=None, device=None),text_signature = "(shape:Sequence[int], dtype:Optional[DType]=None, device:Optional[Device]=None)")] +#[pyo3(signature = (*shape, dtype=None, device=None),text_signature = "(*shape:Shape, dtype:Optional[DType]=None, device:Optional[Device]=None)")] /// Creates a new tensor filled with ones. /// &RETURNS&: Tensor fn ones( @@ -1083,12 +1082,12 @@ fn ones( Some(dtype) => PyDType::from_pyobject(dtype, py)?.0, }; let device = device.unwrap_or(PyDevice::Cpu).as_device()?; - let tensor = Tensor::ones(shape.0, dtype, &device).map_err(wrap_err)?; + let tensor = Tensor::ones(shape, dtype, &device).map_err(wrap_err)?; Ok(PyTensor(tensor)) } #[pyfunction] -#[pyo3(signature = (shape, *, dtype=None, device=None), text_signature = "(shape:Sequence[int], dtype:Optional[DType]=None, device:Optional[Device]=None)")] +#[pyo3(signature = (*shape, dtype=None, device=None), text_signature = "(*shape:Shape, dtype:Optional[DType]=None, device:Optional[Device]=None)")] /// Creates a new tensor filled with zeros. /// &RETURNS&: Tensor fn zeros( @@ -1102,7 +1101,7 @@ fn zeros( Some(dtype) => PyDType::from_pyobject(dtype, py)?.0, }; let device = device.unwrap_or(PyDevice::Cpu).as_device()?; - let tensor = Tensor::zeros(shape.0, dtype, &device).map_err(wrap_err)?; + let tensor = Tensor::zeros(shape, dtype, &device).map_err(wrap_err)?; Ok(PyTensor(tensor)) } diff --git a/candle-pyo3/src/shape.rs b/candle-pyo3/src/shape.rs new file mode 100644 index 00000000..2668b733 --- /dev/null +++ b/candle-pyo3/src/shape.rs @@ -0,0 +1,99 @@ +use ::candle::Tensor; +use pyo3::prelude::*; + +#[derive(Clone, Debug)] +/// Represents an absolute shape e.g. (1, 2, 3) +pub struct PyShape(Vec<usize>); + +impl<'source> pyo3::FromPyObject<'source> for PyShape { + fn extract(ob: &'source PyAny) -> PyResult<Self> { + if ob.is_none() { + return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>( + "Shape cannot be None", + )); + } + + let tuple = ob.downcast::<pyo3::types::PyTuple>()?; + if tuple.len() == 1 { + let first_element = tuple.get_item(0)?; + let dims: Vec<usize> = pyo3::FromPyObject::extract(first_element)?; + Ok(PyShape(dims)) + } else { + let dims: Vec<usize> = pyo3::FromPyObject::extract(tuple)?; + Ok(PyShape(dims)) + } + } +} + +impl From<PyShape> for ::candle::Shape { + fn from(val: PyShape) -> Self { + val.0.into() + } +} + +#[derive(Clone, Debug)] +/// Represents a shape with a hole in it e.g. (1, -1, 3) +pub struct PyShapeWithHole(Vec<isize>); + +impl<'source> pyo3::FromPyObject<'source> for PyShapeWithHole { + fn extract(ob: &'source PyAny) -> PyResult<Self> { + if ob.is_none() { + return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>( + "Shape cannot be None", + )); + } + + let tuple = ob.downcast::<pyo3::types::PyTuple>()?; + let dims: Vec<isize> = if tuple.len() == 1 { + let first_element = tuple.get_item(0)?; + pyo3::FromPyObject::extract(first_element)? + } else { + pyo3::FromPyObject::extract(tuple)? + }; + + // Ensure we have only positive numbers and at most one "hole" (-1) + let negative_ones = dims.iter().filter(|&&x| x == -1).count(); + let any_invalid_dimensions = dims.iter().any(|&x| x < -1 || x == 0); + if negative_ones > 1 || any_invalid_dimensions { + return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!( + "Invalid dimension in shape: {:?}", + dims + ))); + } + + Ok(PyShapeWithHole(dims)) + } +} + +impl PyShapeWithHole { + /// Returns `true` if the shape is absolute e.g. (1, 2, 3) + pub fn is_absolute(&self) -> bool { + self.0.iter().all(|x| *x > 0) + } + + /// Convert a relative shape to an absolute shape e.g. (1, -1) -> (1, 12) + pub fn to_absolute(&self, t: &Tensor) -> PyResult<PyShape> { + if self.is_absolute() { + return Ok(PyShape( + self.0.iter().map(|x| *x as usize).collect::<Vec<usize>>(), + )); + } + + let mut elements = t.elem_count(); + let mut new_dims: Vec<usize> = vec![]; + for dim in self.0.iter() { + if *dim > 0 { + new_dims.push(*dim as usize); + elements /= *dim as usize; + } else if *dim == -1 { + new_dims.push(elements); + } else { + return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!( + "Invalid dimension in shape: {}", + dim + ))); + } + } + Ok(PyShape(new_dims)) + } +} diff --git a/candle-pyo3/stub.py b/candle-pyo3/stub.py index 8e4318bc..336f674b 100644 --- a/candle-pyo3/stub.py +++ b/candle-pyo3/stub.py @@ -13,7 +13,7 @@ GENERATED_COMMENT = "# Generated content DO NOT EDIT\n" TYPING = """from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence from os import PathLike """ -CANDLE_SPECIFIC_TYPING = "from candle.typing import _ArrayLike, Device, Scalar, Index\n" +CANDLE_SPECIFIC_TYPING = "from candle.typing import _ArrayLike, Device, Scalar, Index, Shape\n" CANDLE_TENSOR_IMPORTS = "from candle import Tensor,DType,QTensor\n" RETURN_TYPE_MARKER = "&RETURNS&: " ADDITIONAL_TYPEHINTS = {} diff --git a/candle-pyo3/tests/native/test_shape.py b/candle-pyo3/tests/native/test_shape.py new file mode 100644 index 00000000..864e24d6 --- /dev/null +++ b/candle-pyo3/tests/native/test_shape.py @@ -0,0 +1,31 @@ +from candle import Tensor +from candle import rand +import pytest + + +def test_absolute_shapes_are_valid(): + a = rand((10, 20)) + assert a.shape == (10, 20) + + b = rand(10, 20) + assert b.shape == (10, 20) + pytest.raises(OverflowError, lambda: rand((10, 20, -1))) + pytest.raises(OverflowError, lambda: rand(-1, 20)) + pytest.raises(TypeError, lambda: rand("foo", True)) + + +def test_relative_shapes_are_valid(): + a = rand(10, 20) + a = a.reshape((1, -1)) + assert a.shape == (1, 200) + + b = rand(10, 20) + b = b.reshape(-1, 1) + assert b.shape == (200, 1) + + c = rand(10, 20) + pytest.raises(TypeError, lambda: c.reshape(1, "foo")) + pytest.raises(ValueError, lambda: c.reshape(1, -2)) + pytest.raises(ValueError, lambda: c.reshape((-2, 1))) + pytest.raises(ValueError, lambda: c.reshape((0, 1))) + pytest.raises(ValueError, lambda: c.reshape((1, -1, -1))) |