diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-10-06 10:09:38 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-10-06 10:09:38 +0200 |
commit | f856b5c3a75028d384c26e36501d429091662cd3 (patch) | |
tree | 811569350d124b23bb54f5381756ec7cd0b34278 /candle-pyo3/src/shape.rs | |
parent | d2e432914ec495baff1db29799fe316b9190b0e9 (diff) | |
download | candle-f856b5c3a75028d384c26e36501d429091662cd3.tar.gz candle-f856b5c3a75028d384c26e36501d429091662cd3.tar.bz2 candle-f856b5c3a75028d384c26e36501d429091662cd3.zip |
pyo3 update. (#2545)
* pyo3 update.
* Stub fix.
Diffstat (limited to 'candle-pyo3/src/shape.rs')
-rw-r--r-- | candle-pyo3/src/shape.rs | 12 |
1 files changed, 6 insertions, 6 deletions
diff --git a/candle-pyo3/src/shape.rs b/candle-pyo3/src/shape.rs index 2668b733..b9bc6789 100644 --- a/candle-pyo3/src/shape.rs +++ b/candle-pyo3/src/shape.rs @@ -6,7 +6,7 @@ use pyo3::prelude::*; pub struct PyShape(Vec<usize>); impl<'source> pyo3::FromPyObject<'source> for PyShape { - fn extract(ob: &'source PyAny) -> PyResult<Self> { + fn extract_bound(ob: &Bound<'source, PyAny>) -> PyResult<Self> { if ob.is_none() { return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>( "Shape cannot be None", @@ -16,10 +16,10 @@ impl<'source> pyo3::FromPyObject<'source> for PyShape { let tuple = ob.downcast::<pyo3::types::PyTuple>()?; if tuple.len() == 1 { let first_element = tuple.get_item(0)?; - let dims: Vec<usize> = pyo3::FromPyObject::extract(first_element)?; + let dims: Vec<usize> = pyo3::FromPyObject::extract_bound(&first_element)?; Ok(PyShape(dims)) } else { - let dims: Vec<usize> = pyo3::FromPyObject::extract(tuple)?; + let dims: Vec<usize> = pyo3::FromPyObject::extract_bound(tuple)?; Ok(PyShape(dims)) } } @@ -36,7 +36,7 @@ impl From<PyShape> for ::candle::Shape { pub struct PyShapeWithHole(Vec<isize>); impl<'source> pyo3::FromPyObject<'source> for PyShapeWithHole { - fn extract(ob: &'source PyAny) -> PyResult<Self> { + fn extract_bound(ob: &Bound<'source, PyAny>) -> PyResult<Self> { if ob.is_none() { return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>( "Shape cannot be None", @@ -46,9 +46,9 @@ impl<'source> pyo3::FromPyObject<'source> for PyShapeWithHole { let tuple = ob.downcast::<pyo3::types::PyTuple>()?; let dims: Vec<isize> = if tuple.len() == 1 { let first_element = tuple.get_item(0)?; - pyo3::FromPyObject::extract(first_element)? + pyo3::FromPyObject::extract_bound(&first_element)? } else { - pyo3::FromPyObject::extract(tuple)? + pyo3::FromPyObject::extract_bound(tuple)? }; // Ensure we have only positive numbers and at most one "hole" (-1) |