diff options
Diffstat (limited to 'candle-pyo3/src/lib.rs')
-rw-r--r-- | candle-pyo3/src/lib.rs | 71 |
1 files changed, 35 insertions, 36 deletions
diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs index 6d4de80b..41c4577f 100644 --- a/candle-pyo3/src/lib.rs +++ b/candle-pyo3/src/lib.rs @@ -16,27 +16,14 @@ extern crate accelerate_src; use ::candle::{quantized::QTensor, DType, Device, Tensor, WithDType}; +mod shape; +use shape::{PyShape, PyShapeWithHole}; + pub fn wrap_err(err: ::candle::Error) -> PyErr { PyErr::new::<PyValueError, _>(format!("{err:?}")) } #[derive(Clone, Debug)] -struct PyShape(Vec<usize>); - -impl<'source> pyo3::FromPyObject<'source> for PyShape { - fn extract(ob: &'source PyAny) -> PyResult<Self> { - let dims: Vec<usize> = pyo3::FromPyObject::extract(ob)?; - Ok(PyShape(dims)) - } -} - -impl From<PyShape> for ::candle::Shape { - fn from(val: PyShape) -> Self { - val.0.into() - } -} - -#[derive(Clone, Debug)] #[pyclass(name = "Tensor")] /// A `candle` tensor. struct PyTensor(Tensor); @@ -684,25 +671,37 @@ impl PyTensor { Ok(Self(tensor)) } - #[pyo3(text_signature = "(self, shape:Sequence[int])")] + #[pyo3(signature=(*shape), text_signature = "(self, *shape:Shape)")] /// Reshapes the tensor to the given shape. /// &RETURNS&: Tensor - fn reshape(&self, shape: PyShape) -> PyResult<Self> { - Ok(PyTensor(self.0.reshape(shape).map_err(wrap_err)?)) + fn reshape(&self, shape: PyShapeWithHole) -> PyResult<Self> { + Ok(PyTensor( + self.0 + .reshape(shape.to_absolute(&self.0)?) + .map_err(wrap_err)?, + )) } - #[pyo3(text_signature = "(self, shape:Sequence[int])")] + #[pyo3(signature=(*shape), text_signature = "(self, *shape:Shape)")] /// Broadcasts the tensor to the given shape. /// &RETURNS&: Tensor - fn broadcast_as(&self, shape: PyShape) -> PyResult<Self> { - Ok(PyTensor(self.0.broadcast_as(shape).map_err(wrap_err)?)) + fn broadcast_as(&self, shape: PyShapeWithHole) -> PyResult<Self> { + Ok(PyTensor( + self.0 + .broadcast_as(shape.to_absolute(&self.0)?) + .map_err(wrap_err)?, + )) } - #[pyo3(text_signature = "(self, shape:Sequence[int])")] + #[pyo3(signature=(*shape), text_signature = "(self, *shape:Shape)")] /// Broadcasts the tensor to the given shape, adding new dimensions on the left. /// &RETURNS&: Tensor - fn broadcast_left(&self, shape: PyShape) -> PyResult<Self> { - Ok(PyTensor(self.0.broadcast_left(shape).map_err(wrap_err)?)) + fn broadcast_left(&self, shape: PyShapeWithHole) -> PyResult<Self> { + Ok(PyTensor( + self.0 + .broadcast_left(shape.to_absolute(&self.0)?) + .map_err(wrap_err)?, + )) } #[pyo3(text_signature = "(self, dim:int)")] @@ -915,21 +914,21 @@ impl PyTensor { } if let Some(kwargs) = kwargs { - if let Some(any) = kwargs.get_item("dtype") { + if let Ok(Some(any)) = kwargs.get_item("dtype") { handle_duplicates( &mut dtype, any.extract::<PyDType>(), "cannot specify multiple dtypes", )?; } - if let Some(any) = kwargs.get_item("device") { + if let Ok(Some(any)) = kwargs.get_item("device") { handle_duplicates( &mut device, any.extract::<PyDevice>(), "cannot specify multiple devices", )?; } - if let Some(any) = kwargs.get_item("other") { + if let Ok(Some(any)) = kwargs.get_item("other") { handle_duplicates( &mut other, any.extract::<PyTensor>(), @@ -1049,27 +1048,27 @@ fn tensor(py: Python<'_>, data: PyObject) -> PyResult<PyTensor> { } #[pyfunction] -#[pyo3(signature = (shape, *, device=None), text_signature = "(shape:Sequence[int], device:Optional[Device]=None)")] +#[pyo3(signature = (*shape,device=None), text_signature = "(*shape:Shape, device:Optional[Device]=None)")] /// Creates a new tensor with random values. /// &RETURNS&: Tensor fn rand(_py: Python<'_>, shape: PyShape, device: Option<PyDevice>) -> PyResult<PyTensor> { let device = device.unwrap_or(PyDevice::Cpu).as_device()?; - let tensor = Tensor::rand(0f32, 1f32, shape.0, &device).map_err(wrap_err)?; + let tensor = Tensor::rand(0f32, 1f32, shape, &device).map_err(wrap_err)?; Ok(PyTensor(tensor)) } #[pyfunction] -#[pyo3(signature = (shape, *, device=None), text_signature = "(shape:Sequence[int], device:Optional[Device]=None)")] +#[pyo3(signature = (*shape,device=None), text_signature = "(*shape:Shape, device:Optional[Device]=None)")] /// Creates a new tensor with random values from a normal distribution. /// &RETURNS&: Tensor fn randn(_py: Python<'_>, shape: PyShape, device: Option<PyDevice>) -> PyResult<PyTensor> { let device = device.unwrap_or(PyDevice::Cpu).as_device()?; - let tensor = Tensor::randn(0f32, 1f32, shape.0, &device).map_err(wrap_err)?; + let tensor = Tensor::randn(0f32, 1f32, shape, &device).map_err(wrap_err)?; Ok(PyTensor(tensor)) } #[pyfunction] -#[pyo3(signature = (shape, *, dtype=None, device=None),text_signature = "(shape:Sequence[int], dtype:Optional[DType]=None, device:Optional[Device]=None)")] +#[pyo3(signature = (*shape, dtype=None, device=None),text_signature = "(*shape:Shape, dtype:Optional[DType]=None, device:Optional[Device]=None)")] /// Creates a new tensor filled with ones. /// &RETURNS&: Tensor fn ones( @@ -1083,12 +1082,12 @@ fn ones( Some(dtype) => PyDType::from_pyobject(dtype, py)?.0, }; let device = device.unwrap_or(PyDevice::Cpu).as_device()?; - let tensor = Tensor::ones(shape.0, dtype, &device).map_err(wrap_err)?; + let tensor = Tensor::ones(shape, dtype, &device).map_err(wrap_err)?; Ok(PyTensor(tensor)) } #[pyfunction] -#[pyo3(signature = (shape, *, dtype=None, device=None), text_signature = "(shape:Sequence[int], dtype:Optional[DType]=None, device:Optional[Device]=None)")] +#[pyo3(signature = (*shape, dtype=None, device=None), text_signature = "(*shape:Shape, dtype:Optional[DType]=None, device:Optional[Device]=None)")] /// Creates a new tensor filled with zeros. /// &RETURNS&: Tensor fn zeros( @@ -1102,7 +1101,7 @@ fn zeros( Some(dtype) => PyDType::from_pyobject(dtype, py)?.0, }; let device = device.unwrap_or(PyDevice::Cpu).as_device()?; - let tensor = Tensor::zeros(shape.0, dtype, &device).map_err(wrap_err)?; + let tensor = Tensor::zeros(shape, dtype, &device).map_err(wrap_err)?; Ok(PyTensor(tensor)) } |