summaryrefslogtreecommitdiff
path: root/candle-pyo3
diff options
context:
space:
mode:
authorLukas Kreussel <65088241+LLukas22@users.noreply.github.com>2023-10-29 16:41:44 +0100
committerGitHub <noreply@github.com>2023-10-29 15:41:44 +0000
commit174b20805230abaf91b838598d84ab142f31a975 (patch)
tree775e1ca849f2e9f0ac2a32df476106ec1bb3a5dc /candle-pyo3
parent154c674a798fd5a40d57ff9a8664856d9c41ca56 (diff)
downloadcandle-174b20805230abaf91b838598d84ab142f31a975.tar.gz
candle-174b20805230abaf91b838598d84ab142f31a975.tar.bz2
candle-174b20805230abaf91b838598d84ab142f31a975.zip
PyO3: Better shape handling (#1143)
* Negative and `*args` shape handling * Rename to `PyShapeWithHole` + validate that only one hole exists * Regenerate stubs --------- Co-authored-by: Laurent Mazare <laurent.mazare@gmail.com>
Diffstat (limited to 'candle-pyo3')
-rw-r--r--candle-pyo3/Cargo.toml4
-rw-r--r--candle-pyo3/py_src/candle/__init__.pyi16
-rw-r--r--candle-pyo3/py_src/candle/functional/__init__.pyi2
-rw-r--r--candle-pyo3/py_src/candle/typing/__init__.py2
-rw-r--r--candle-pyo3/py_src/candle/utils/__init__.pyi2
-rw-r--r--candle-pyo3/src/lib.rs71
-rw-r--r--candle-pyo3/src/shape.rs99
-rw-r--r--candle-pyo3/stub.py2
-rw-r--r--candle-pyo3/tests/native/test_shape.py31
9 files changed, 180 insertions, 49 deletions
diff --git a/candle-pyo3/Cargo.toml b/candle-pyo3/Cargo.toml
index 488404bf..b0452404 100644
--- a/candle-pyo3/Cargo.toml
+++ b/candle-pyo3/Cargo.toml
@@ -19,10 +19,10 @@ candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
candle-nn = { path = "../candle-nn", version = "0.3.0" }
half = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
-pyo3 = { version = "0.19.0", features = ["extension-module", "abi3-py38"] }
+pyo3 = { version = "0.20.0", features = ["extension-module", "abi3-py38"] }
[build-dependencies]
-pyo3-build-config = "0.19"
+pyo3-build-config = "0.20"
[features]
default = []
diff --git a/candle-pyo3/py_src/candle/__init__.pyi b/candle-pyo3/py_src/candle/__init__.pyi
index 43722168..35b17680 100644
--- a/candle-pyo3/py_src/candle/__init__.pyi
+++ b/candle-pyo3/py_src/candle/__init__.pyi
@@ -1,7 +1,7 @@
# Generated content DO NOT EDIT
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence
from os import PathLike
-from candle.typing import _ArrayLike, Device, Scalar, Index
+from candle.typing import _ArrayLike, Device, Scalar, Index, Shape
class bf16(DType):
pass
@@ -26,21 +26,21 @@ class i64(DType):
pass
@staticmethod
-def ones(shape: Sequence[int], dtype: Optional[DType] = None, device: Optional[Device] = None) -> Tensor:
+def ones(*shape: Shape, dtype: Optional[DType] = None, device: Optional[Device] = None) -> Tensor:
"""
Creates a new tensor filled with ones.
"""
pass
@staticmethod
-def rand(shape: Sequence[int], device: Optional[Device] = None) -> Tensor:
+def rand(*shape: Shape, device: Optional[Device] = None) -> Tensor:
"""
Creates a new tensor with random values.
"""
pass
@staticmethod
-def randn(shape: Sequence[int], device: Optional[Device] = None) -> Tensor:
+def randn(*shape: Shape, device: Optional[Device] = None) -> Tensor:
"""
Creates a new tensor with random values from a normal distribution.
"""
@@ -67,7 +67,7 @@ class u8(DType):
pass
@staticmethod
-def zeros(shape: Sequence[int], dtype: Optional[DType] = None, device: Optional[Device] = None) -> Tensor:
+def zeros(*shape: Shape, dtype: Optional[DType] = None, device: Optional[Device] = None) -> Tensor:
"""
Creates a new tensor filled with zeros.
"""
@@ -174,7 +174,7 @@ class Tensor:
Adds the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor.
"""
pass
- def broadcast_as(self, shape: Sequence[int]) -> Tensor:
+ def broadcast_as(self, *shape: Shape) -> Tensor:
"""
Broadcasts the tensor to the given shape.
"""
@@ -184,7 +184,7 @@ class Tensor:
Divides the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor.
"""
pass
- def broadcast_left(self, shape: Sequence[int]) -> Tensor:
+ def broadcast_left(self, *shape: Shape) -> Tensor:
"""
Broadcasts the tensor to the given shape, adding new dimensions on the left.
"""
@@ -329,7 +329,7 @@ class Tensor:
Get the `recip` of the tensor.
"""
pass
- def reshape(self, shape: Sequence[int]) -> Tensor:
+ def reshape(self, *shape: Shape) -> Tensor:
"""
Reshapes the tensor to the given shape.
"""
diff --git a/candle-pyo3/py_src/candle/functional/__init__.pyi b/candle-pyo3/py_src/candle/functional/__init__.pyi
index 5bf5c4c3..4f7c2aa6 100644
--- a/candle-pyo3/py_src/candle/functional/__init__.pyi
+++ b/candle-pyo3/py_src/candle/functional/__init__.pyi
@@ -1,7 +1,7 @@
# Generated content DO NOT EDIT
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence
from os import PathLike
-from candle.typing import _ArrayLike, Device, Scalar, Index
+from candle.typing import _ArrayLike, Device, Scalar, Index, Shape
from candle import Tensor, DType, QTensor
@staticmethod
diff --git a/candle-pyo3/py_src/candle/typing/__init__.py b/candle-pyo3/py_src/candle/typing/__init__.py
index 66bc3d8a..b2262a97 100644
--- a/candle-pyo3/py_src/candle/typing/__init__.py
+++ b/candle-pyo3/py_src/candle/typing/__init__.py
@@ -18,3 +18,5 @@ Device = TypeVar("Device", CPU, CUDA)
Scalar = Union[int, float]
Index = Union[int, slice, None, "Ellipsis"]
+
+Shape = Union[int, Sequence[int]]
diff --git a/candle-pyo3/py_src/candle/utils/__init__.pyi b/candle-pyo3/py_src/candle/utils/__init__.pyi
index d3b93766..4ee51c29 100644
--- a/candle-pyo3/py_src/candle/utils/__init__.pyi
+++ b/candle-pyo3/py_src/candle/utils/__init__.pyi
@@ -1,7 +1,7 @@
# Generated content DO NOT EDIT
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence
from os import PathLike
-from candle.typing import _ArrayLike, Device, Scalar, Index
+from candle.typing import _ArrayLike, Device, Scalar, Index, Shape
from candle import Tensor, DType, QTensor
@staticmethod
diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs
index 6d4de80b..41c4577f 100644
--- a/candle-pyo3/src/lib.rs
+++ b/candle-pyo3/src/lib.rs
@@ -16,27 +16,14 @@ extern crate accelerate_src;
use ::candle::{quantized::QTensor, DType, Device, Tensor, WithDType};
+mod shape;
+use shape::{PyShape, PyShapeWithHole};
+
pub fn wrap_err(err: ::candle::Error) -> PyErr {
PyErr::new::<PyValueError, _>(format!("{err:?}"))
}
#[derive(Clone, Debug)]
-struct PyShape(Vec<usize>);
-
-impl<'source> pyo3::FromPyObject<'source> for PyShape {
- fn extract(ob: &'source PyAny) -> PyResult<Self> {
- let dims: Vec<usize> = pyo3::FromPyObject::extract(ob)?;
- Ok(PyShape(dims))
- }
-}
-
-impl From<PyShape> for ::candle::Shape {
- fn from(val: PyShape) -> Self {
- val.0.into()
- }
-}
-
-#[derive(Clone, Debug)]
#[pyclass(name = "Tensor")]
/// A `candle` tensor.
struct PyTensor(Tensor);
@@ -684,25 +671,37 @@ impl PyTensor {
Ok(Self(tensor))
}
- #[pyo3(text_signature = "(self, shape:Sequence[int])")]
+ #[pyo3(signature=(*shape), text_signature = "(self, *shape:Shape)")]
/// Reshapes the tensor to the given shape.
/// &RETURNS&: Tensor
- fn reshape(&self, shape: PyShape) -> PyResult<Self> {
- Ok(PyTensor(self.0.reshape(shape).map_err(wrap_err)?))
+ fn reshape(&self, shape: PyShapeWithHole) -> PyResult<Self> {
+ Ok(PyTensor(
+ self.0
+ .reshape(shape.to_absolute(&self.0)?)
+ .map_err(wrap_err)?,
+ ))
}
- #[pyo3(text_signature = "(self, shape:Sequence[int])")]
+ #[pyo3(signature=(*shape), text_signature = "(self, *shape:Shape)")]
/// Broadcasts the tensor to the given shape.
/// &RETURNS&: Tensor
- fn broadcast_as(&self, shape: PyShape) -> PyResult<Self> {
- Ok(PyTensor(self.0.broadcast_as(shape).map_err(wrap_err)?))
+ fn broadcast_as(&self, shape: PyShapeWithHole) -> PyResult<Self> {
+ Ok(PyTensor(
+ self.0
+ .broadcast_as(shape.to_absolute(&self.0)?)
+ .map_err(wrap_err)?,
+ ))
}
- #[pyo3(text_signature = "(self, shape:Sequence[int])")]
+ #[pyo3(signature=(*shape), text_signature = "(self, *shape:Shape)")]
/// Broadcasts the tensor to the given shape, adding new dimensions on the left.
/// &RETURNS&: Tensor
- fn broadcast_left(&self, shape: PyShape) -> PyResult<Self> {
- Ok(PyTensor(self.0.broadcast_left(shape).map_err(wrap_err)?))
+ fn broadcast_left(&self, shape: PyShapeWithHole) -> PyResult<Self> {
+ Ok(PyTensor(
+ self.0
+ .broadcast_left(shape.to_absolute(&self.0)?)
+ .map_err(wrap_err)?,
+ ))
}
#[pyo3(text_signature = "(self, dim:int)")]
@@ -915,21 +914,21 @@ impl PyTensor {
}
if let Some(kwargs) = kwargs {
- if let Some(any) = kwargs.get_item("dtype") {
+ if let Ok(Some(any)) = kwargs.get_item("dtype") {
handle_duplicates(
&mut dtype,
any.extract::<PyDType>(),
"cannot specify multiple dtypes",
)?;
}
- if let Some(any) = kwargs.get_item("device") {
+ if let Ok(Some(any)) = kwargs.get_item("device") {
handle_duplicates(
&mut device,
any.extract::<PyDevice>(),
"cannot specify multiple devices",
)?;
}
- if let Some(any) = kwargs.get_item("other") {
+ if let Ok(Some(any)) = kwargs.get_item("other") {
handle_duplicates(
&mut other,
any.extract::<PyTensor>(),
@@ -1049,27 +1048,27 @@ fn tensor(py: Python<'_>, data: PyObject) -> PyResult<PyTensor> {
}
#[pyfunction]
-#[pyo3(signature = (shape, *, device=None), text_signature = "(shape:Sequence[int], device:Optional[Device]=None)")]
+#[pyo3(signature = (*shape,device=None), text_signature = "(*shape:Shape, device:Optional[Device]=None)")]
/// Creates a new tensor with random values.
/// &RETURNS&: Tensor
fn rand(_py: Python<'_>, shape: PyShape, device: Option<PyDevice>) -> PyResult<PyTensor> {
let device = device.unwrap_or(PyDevice::Cpu).as_device()?;
- let tensor = Tensor::rand(0f32, 1f32, shape.0, &device).map_err(wrap_err)?;
+ let tensor = Tensor::rand(0f32, 1f32, shape, &device).map_err(wrap_err)?;
Ok(PyTensor(tensor))
}
#[pyfunction]
-#[pyo3(signature = (shape, *, device=None), text_signature = "(shape:Sequence[int], device:Optional[Device]=None)")]
+#[pyo3(signature = (*shape,device=None), text_signature = "(*shape:Shape, device:Optional[Device]=None)")]
/// Creates a new tensor with random values from a normal distribution.
/// &RETURNS&: Tensor
fn randn(_py: Python<'_>, shape: PyShape, device: Option<PyDevice>) -> PyResult<PyTensor> {
let device = device.unwrap_or(PyDevice::Cpu).as_device()?;
- let tensor = Tensor::randn(0f32, 1f32, shape.0, &device).map_err(wrap_err)?;
+ let tensor = Tensor::randn(0f32, 1f32, shape, &device).map_err(wrap_err)?;
Ok(PyTensor(tensor))
}
#[pyfunction]
-#[pyo3(signature = (shape, *, dtype=None, device=None),text_signature = "(shape:Sequence[int], dtype:Optional[DType]=None, device:Optional[Device]=None)")]
+#[pyo3(signature = (*shape, dtype=None, device=None),text_signature = "(*shape:Shape, dtype:Optional[DType]=None, device:Optional[Device]=None)")]
/// Creates a new tensor filled with ones.
/// &RETURNS&: Tensor
fn ones(
@@ -1083,12 +1082,12 @@ fn ones(
Some(dtype) => PyDType::from_pyobject(dtype, py)?.0,
};
let device = device.unwrap_or(PyDevice::Cpu).as_device()?;
- let tensor = Tensor::ones(shape.0, dtype, &device).map_err(wrap_err)?;
+ let tensor = Tensor::ones(shape, dtype, &device).map_err(wrap_err)?;
Ok(PyTensor(tensor))
}
#[pyfunction]
-#[pyo3(signature = (shape, *, dtype=None, device=None), text_signature = "(shape:Sequence[int], dtype:Optional[DType]=None, device:Optional[Device]=None)")]
+#[pyo3(signature = (*shape, dtype=None, device=None), text_signature = "(*shape:Shape, dtype:Optional[DType]=None, device:Optional[Device]=None)")]
/// Creates a new tensor filled with zeros.
/// &RETURNS&: Tensor
fn zeros(
@@ -1102,7 +1101,7 @@ fn zeros(
Some(dtype) => PyDType::from_pyobject(dtype, py)?.0,
};
let device = device.unwrap_or(PyDevice::Cpu).as_device()?;
- let tensor = Tensor::zeros(shape.0, dtype, &device).map_err(wrap_err)?;
+ let tensor = Tensor::zeros(shape, dtype, &device).map_err(wrap_err)?;
Ok(PyTensor(tensor))
}
diff --git a/candle-pyo3/src/shape.rs b/candle-pyo3/src/shape.rs
new file mode 100644
index 00000000..2668b733
--- /dev/null
+++ b/candle-pyo3/src/shape.rs
@@ -0,0 +1,99 @@
+use ::candle::Tensor;
+use pyo3::prelude::*;
+
+#[derive(Clone, Debug)]
+/// Represents an absolute shape e.g. (1, 2, 3)
+pub struct PyShape(Vec<usize>);
+
+impl<'source> pyo3::FromPyObject<'source> for PyShape {
+ fn extract(ob: &'source PyAny) -> PyResult<Self> {
+ if ob.is_none() {
+ return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
+ "Shape cannot be None",
+ ));
+ }
+
+ let tuple = ob.downcast::<pyo3::types::PyTuple>()?;
+ if tuple.len() == 1 {
+ let first_element = tuple.get_item(0)?;
+ let dims: Vec<usize> = pyo3::FromPyObject::extract(first_element)?;
+ Ok(PyShape(dims))
+ } else {
+ let dims: Vec<usize> = pyo3::FromPyObject::extract(tuple)?;
+ Ok(PyShape(dims))
+ }
+ }
+}
+
+impl From<PyShape> for ::candle::Shape {
+ fn from(val: PyShape) -> Self {
+ val.0.into()
+ }
+}
+
+#[derive(Clone, Debug)]
+/// Represents a shape with a hole in it e.g. (1, -1, 3)
+pub struct PyShapeWithHole(Vec<isize>);
+
+impl<'source> pyo3::FromPyObject<'source> for PyShapeWithHole {
+ fn extract(ob: &'source PyAny) -> PyResult<Self> {
+ if ob.is_none() {
+ return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
+ "Shape cannot be None",
+ ));
+ }
+
+ let tuple = ob.downcast::<pyo3::types::PyTuple>()?;
+ let dims: Vec<isize> = if tuple.len() == 1 {
+ let first_element = tuple.get_item(0)?;
+ pyo3::FromPyObject::extract(first_element)?
+ } else {
+ pyo3::FromPyObject::extract(tuple)?
+ };
+
+ // Ensure we have only positive numbers and at most one "hole" (-1)
+ let negative_ones = dims.iter().filter(|&&x| x == -1).count();
+ let any_invalid_dimensions = dims.iter().any(|&x| x < -1 || x == 0);
+ if negative_ones > 1 || any_invalid_dimensions {
+ return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
+ "Invalid dimension in shape: {:?}",
+ dims
+ )));
+ }
+
+ Ok(PyShapeWithHole(dims))
+ }
+}
+
+impl PyShapeWithHole {
+ /// Returns `true` if the shape is absolute e.g. (1, 2, 3)
+ pub fn is_absolute(&self) -> bool {
+ self.0.iter().all(|x| *x > 0)
+ }
+
+ /// Convert a relative shape to an absolute shape e.g. (1, -1) -> (1, 12)
+ pub fn to_absolute(&self, t: &Tensor) -> PyResult<PyShape> {
+ if self.is_absolute() {
+ return Ok(PyShape(
+ self.0.iter().map(|x| *x as usize).collect::<Vec<usize>>(),
+ ));
+ }
+
+ let mut elements = t.elem_count();
+ let mut new_dims: Vec<usize> = vec![];
+ for dim in self.0.iter() {
+ if *dim > 0 {
+ new_dims.push(*dim as usize);
+ elements /= *dim as usize;
+ } else if *dim == -1 {
+ new_dims.push(elements);
+ } else {
+ return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
+ "Invalid dimension in shape: {}",
+ dim
+ )));
+ }
+ }
+ Ok(PyShape(new_dims))
+ }
+}
diff --git a/candle-pyo3/stub.py b/candle-pyo3/stub.py
index 8e4318bc..336f674b 100644
--- a/candle-pyo3/stub.py
+++ b/candle-pyo3/stub.py
@@ -13,7 +13,7 @@ GENERATED_COMMENT = "# Generated content DO NOT EDIT\n"
TYPING = """from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence
from os import PathLike
"""
-CANDLE_SPECIFIC_TYPING = "from candle.typing import _ArrayLike, Device, Scalar, Index\n"
+CANDLE_SPECIFIC_TYPING = "from candle.typing import _ArrayLike, Device, Scalar, Index, Shape\n"
CANDLE_TENSOR_IMPORTS = "from candle import Tensor,DType,QTensor\n"
RETURN_TYPE_MARKER = "&RETURNS&: "
ADDITIONAL_TYPEHINTS = {}
diff --git a/candle-pyo3/tests/native/test_shape.py b/candle-pyo3/tests/native/test_shape.py
new file mode 100644
index 00000000..864e24d6
--- /dev/null
+++ b/candle-pyo3/tests/native/test_shape.py
@@ -0,0 +1,31 @@
+from candle import Tensor
+from candle import rand
+import pytest
+
+
+def test_absolute_shapes_are_valid():
+ a = rand((10, 20))
+ assert a.shape == (10, 20)
+
+ b = rand(10, 20)
+ assert b.shape == (10, 20)
+ pytest.raises(OverflowError, lambda: rand((10, 20, -1)))
+ pytest.raises(OverflowError, lambda: rand(-1, 20))
+ pytest.raises(TypeError, lambda: rand("foo", True))
+
+
+def test_relative_shapes_are_valid():
+ a = rand(10, 20)
+ a = a.reshape((1, -1))
+ assert a.shape == (1, 200)
+
+ b = rand(10, 20)
+ b = b.reshape(-1, 1)
+ assert b.shape == (200, 1)
+
+ c = rand(10, 20)
+ pytest.raises(TypeError, lambda: c.reshape(1, "foo"))
+ pytest.raises(ValueError, lambda: c.reshape(1, -2))
+ pytest.raises(ValueError, lambda: c.reshape((-2, 1)))
+ pytest.raises(ValueError, lambda: c.reshape((0, 1)))
+ pytest.raises(ValueError, lambda: c.reshape((1, -1, -1)))