diff options
author | Lukas Kreussel <65088241+LLukas22@users.noreply.github.com> | 2023-10-29 16:41:44 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-10-29 15:41:44 +0000 |
commit | 174b20805230abaf91b838598d84ab142f31a975 (patch) | |
tree | 775e1ca849f2e9f0ac2a32df476106ec1bb3a5dc /candle-pyo3/src/shape.rs | |
parent | 154c674a798fd5a40d57ff9a8664856d9c41ca56 (diff) | |
download | candle-174b20805230abaf91b838598d84ab142f31a975.tar.gz candle-174b20805230abaf91b838598d84ab142f31a975.tar.bz2 candle-174b20805230abaf91b838598d84ab142f31a975.zip |
PyO3: Better shape handling (#1143)
* Negative and `*args` shape handling
* Rename to `PyShapeWithHole` + validate that only one hole exists
* Regenerate stubs
---------
Co-authored-by: Laurent Mazare <laurent.mazare@gmail.com>
Diffstat (limited to 'candle-pyo3/src/shape.rs')
-rw-r--r-- | candle-pyo3/src/shape.rs | 99 |
1 files changed, 99 insertions, 0 deletions
diff --git a/candle-pyo3/src/shape.rs b/candle-pyo3/src/shape.rs new file mode 100644 index 00000000..2668b733 --- /dev/null +++ b/candle-pyo3/src/shape.rs @@ -0,0 +1,99 @@ +use ::candle::Tensor; +use pyo3::prelude::*; + +#[derive(Clone, Debug)] +/// Represents an absolute shape e.g. (1, 2, 3) +pub struct PyShape(Vec<usize>); + +impl<'source> pyo3::FromPyObject<'source> for PyShape { + fn extract(ob: &'source PyAny) -> PyResult<Self> { + if ob.is_none() { + return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>( + "Shape cannot be None", + )); + } + + let tuple = ob.downcast::<pyo3::types::PyTuple>()?; + if tuple.len() == 1 { + let first_element = tuple.get_item(0)?; + let dims: Vec<usize> = pyo3::FromPyObject::extract(first_element)?; + Ok(PyShape(dims)) + } else { + let dims: Vec<usize> = pyo3::FromPyObject::extract(tuple)?; + Ok(PyShape(dims)) + } + } +} + +impl From<PyShape> for ::candle::Shape { + fn from(val: PyShape) -> Self { + val.0.into() + } +} + +#[derive(Clone, Debug)] +/// Represents a shape with a hole in it e.g. (1, -1, 3) +pub struct PyShapeWithHole(Vec<isize>); + +impl<'source> pyo3::FromPyObject<'source> for PyShapeWithHole { + fn extract(ob: &'source PyAny) -> PyResult<Self> { + if ob.is_none() { + return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>( + "Shape cannot be None", + )); + } + + let tuple = ob.downcast::<pyo3::types::PyTuple>()?; + let dims: Vec<isize> = if tuple.len() == 1 { + let first_element = tuple.get_item(0)?; + pyo3::FromPyObject::extract(first_element)? + } else { + pyo3::FromPyObject::extract(tuple)? + }; + + // Ensure we have only positive numbers and at most one "hole" (-1) + let negative_ones = dims.iter().filter(|&&x| x == -1).count(); + let any_invalid_dimensions = dims.iter().any(|&x| x < -1 || x == 0); + if negative_ones > 1 || any_invalid_dimensions { + return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!( + "Invalid dimension in shape: {:?}", + dims + ))); + } + + Ok(PyShapeWithHole(dims)) + } +} + +impl PyShapeWithHole { + /// Returns `true` if the shape is absolute e.g. (1, 2, 3) + pub fn is_absolute(&self) -> bool { + self.0.iter().all(|x| *x > 0) + } + + /// Convert a relative shape to an absolute shape e.g. (1, -1) -> (1, 12) + pub fn to_absolute(&self, t: &Tensor) -> PyResult<PyShape> { + if self.is_absolute() { + return Ok(PyShape( + self.0.iter().map(|x| *x as usize).collect::<Vec<usize>>(), + )); + } + + let mut elements = t.elem_count(); + let mut new_dims: Vec<usize> = vec![]; + for dim in self.0.iter() { + if *dim > 0 { + new_dims.push(*dim as usize); + elements /= *dim as usize; + } else if *dim == -1 { + new_dims.push(elements); + } else { + return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!( + "Invalid dimension in shape: {}", + dim + ))); + } + } + Ok(PyShape(new_dims)) + } +} |