summaryrefslogtreecommitdiff
path: root/candle-pyo3/src/shape.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-pyo3/src/shape.rs')
-rw-r--r--candle-pyo3/src/shape.rs12
1 files changed, 6 insertions, 6 deletions
diff --git a/candle-pyo3/src/shape.rs b/candle-pyo3/src/shape.rs
index 2668b733..b9bc6789 100644
--- a/candle-pyo3/src/shape.rs
+++ b/candle-pyo3/src/shape.rs
@@ -6,7 +6,7 @@ use pyo3::prelude::*;
pub struct PyShape(Vec<usize>);
impl<'source> pyo3::FromPyObject<'source> for PyShape {
- fn extract(ob: &'source PyAny) -> PyResult<Self> {
+ fn extract_bound(ob: &Bound<'source, PyAny>) -> PyResult<Self> {
if ob.is_none() {
return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
"Shape cannot be None",
@@ -16,10 +16,10 @@ impl<'source> pyo3::FromPyObject<'source> for PyShape {
let tuple = ob.downcast::<pyo3::types::PyTuple>()?;
if tuple.len() == 1 {
let first_element = tuple.get_item(0)?;
- let dims: Vec<usize> = pyo3::FromPyObject::extract(first_element)?;
+ let dims: Vec<usize> = pyo3::FromPyObject::extract_bound(&first_element)?;
Ok(PyShape(dims))
} else {
- let dims: Vec<usize> = pyo3::FromPyObject::extract(tuple)?;
+ let dims: Vec<usize> = pyo3::FromPyObject::extract_bound(tuple)?;
Ok(PyShape(dims))
}
}
@@ -36,7 +36,7 @@ impl From<PyShape> for ::candle::Shape {
pub struct PyShapeWithHole(Vec<isize>);
impl<'source> pyo3::FromPyObject<'source> for PyShapeWithHole {
- fn extract(ob: &'source PyAny) -> PyResult<Self> {
+ fn extract_bound(ob: &Bound<'source, PyAny>) -> PyResult<Self> {
if ob.is_none() {
return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
"Shape cannot be None",
@@ -46,9 +46,9 @@ impl<'source> pyo3::FromPyObject<'source> for PyShapeWithHole {
let tuple = ob.downcast::<pyo3::types::PyTuple>()?;
let dims: Vec<isize> = if tuple.len() == 1 {
let first_element = tuple.get_item(0)?;
- pyo3::FromPyObject::extract(first_element)?
+ pyo3::FromPyObject::extract_bound(&first_element)?
} else {
- pyo3::FromPyObject::extract(tuple)?
+ pyo3::FromPyObject::extract_bound(tuple)?
};
// Ensure we have only positive numbers and at most one "hole" (-1)