summaryrefslogtreecommitdiff
path: root/candle-pyo3
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2024-04-01 17:07:02 +0200
committerGitHub <noreply@github.com>2024-04-01 17:07:02 +0200
commitb20acd622ced28f062d9f91410948282c10661ce (patch)
tree3103031a9db70363fd1aac0d4aa57ebc49804415 /candle-pyo3
parent5522bbc57c2967f3c8fb8fa9ab8a82d2c9ff8db8 (diff)
downloadcandle-b20acd622ced28f062d9f91410948282c10661ce.tar.gz
candle-b20acd622ced28f062d9f91410948282c10661ce.tar.bz2
candle-b20acd622ced28f062d9f91410948282c10661ce.zip
Update for pyo3 0.21. (#1985)
* Update for pyo3 0.21. * Also adapt the RL example. * Fix for the pyo3-onnx bindings... * Print details on failures. * Revert pyi.
Diffstat (limited to 'candle-pyo3')
-rw-r--r--candle-pyo3/Cargo.toml4
-rw-r--r--candle-pyo3/py_src/candle/nn/__init__.pyi19
-rw-r--r--candle-pyo3/src/lib.rs88
-rw-r--r--candle-pyo3/src/onnx.rs2
-rw-r--r--candle-pyo3/stub.py4
5 files changed, 70 insertions, 47 deletions
diff --git a/candle-pyo3/Cargo.toml b/candle-pyo3/Cargo.toml
index 7c6fbd68..88001334 100644
--- a/candle-pyo3/Cargo.toml
+++ b/candle-pyo3/Cargo.toml
@@ -20,10 +20,10 @@ candle-nn = { workspace = true }
candle-onnx = { workspace = true, optional = true }
half = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
-pyo3 = { version = "0.20.0", features = ["extension-module", "abi3-py38"] }
+pyo3 = { version = "0.21.0", features = ["extension-module", "abi3-py38"] }
[build-dependencies]
-pyo3-build-config = "0.20"
+pyo3-build-config = "0.21"
[features]
default = []
diff --git a/candle-pyo3/py_src/candle/nn/__init__.pyi b/candle-pyo3/py_src/candle/nn/__init__.pyi
new file mode 100644
index 00000000..118c4cff
--- /dev/null
+++ b/candle-pyo3/py_src/candle/nn/__init__.pyi
@@ -0,0 +1,19 @@
+# Generated content DO NOT EDIT
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence
+from os import PathLike
+from candle.typing import _ArrayLike, Device, Scalar, Index, Shape
+from candle import Tensor, DType, QTensor
+
+@staticmethod
+def silu(tensor: Tensor) -> Tensor:
+ """
+ Applies the Sigmoid Linear Unit (SiLU) function to a given tensor.
+ """
+ pass
+
+@staticmethod
+def softmax(tensor: Tensor, dim: int) -> Tensor:
+ """
+ Applies the Softmax function to a given tensor.#
+ """
+ pass
diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs
index e0d3bf30..0da2c700 100644
--- a/candle-pyo3/src/lib.rs
+++ b/candle-pyo3/src/lib.rs
@@ -60,8 +60,8 @@ impl PyDType {
impl PyDType {
fn from_pyobject(ob: PyObject, py: Python<'_>) -> PyResult<Self> {
use std::str::FromStr;
- if let Ok(dtype) = ob.extract::<&str>(py) {
- let dtype = DType::from_str(dtype)
+ if let Ok(dtype) = ob.extract::<String>(py) {
+ let dtype = DType::from_str(&dtype)
.map_err(|_| PyTypeError::new_err(format!("invalid dtype '{dtype}'")))?;
Ok(Self(dtype))
} else {
@@ -116,8 +116,8 @@ impl PyDevice {
impl<'source> FromPyObject<'source> for PyDevice {
fn extract(ob: &'source PyAny) -> PyResult<Self> {
- let device: &str = ob.extract()?;
- let device = match device {
+ let device: String = ob.extract()?;
+ let device = match device.as_str() {
"cpu" => PyDevice::Cpu,
"cuda" => PyDevice::Cuda,
_ => Err(PyTypeError::new_err(format!("invalid device '{device}'")))?,
@@ -265,7 +265,7 @@ impl PyTensor {
} else if let Ok(TorchTensor(numpy)) = data.extract::<TorchTensor>(py) {
return PyTensor::new(py, numpy);
} else {
- let ty = data.as_ref(py).get_type();
+ let ty = data.bind(py).get_type();
Err(PyTypeError::new_err(format!(
"incorrect type {ty} for tensor"
)))?
@@ -322,7 +322,7 @@ impl PyTensor {
fn to_torch(&self, py: Python<'_>) -> PyResult<PyObject> {
let candle_values = self.values(py)?;
let torch_tensor: PyObject = py
- .import("torch")?
+ .import_bound("torch")?
.getattr("tensor")?
.call1((candle_values,))?
.extract()?;
@@ -333,7 +333,7 @@ impl PyTensor {
/// Gets the tensor's shape.
/// &RETURNS&: Tuple[int]
fn shape(&self, py: Python<'_>) -> PyObject {
- PyTuple::new(py, self.0.dims()).to_object(py)
+ PyTuple::new_bound(py, self.0.dims()).to_object(py)
}
#[getter]
@@ -347,7 +347,7 @@ impl PyTensor {
/// Gets the tensor's strides.
/// &RETURNS&: Tuple[int]
fn stride(&self, py: Python<'_>) -> PyObject {
- PyTuple::new(py, self.0.stride()).to_object(py)
+ PyTuple::new_bound(py, self.0.stride()).to_object(py)
}
#[getter]
@@ -527,7 +527,7 @@ impl PyTensor {
}
fn extract_indexer(
- py_indexer: &PyAny,
+ py_indexer: &Bound<PyAny>,
current_dim: usize,
dims: &[usize],
index_argument_count: usize,
@@ -567,7 +567,7 @@ impl PyTensor {
),
current_dim + 1,
))
- } else if py_indexer.is_ellipsis() {
+ } else if py_indexer.is(&py_indexer.py().Ellipsis()) {
// Handle '...' e.g. tensor[..., 0]
if current_dim > 0 {
return Err(PyTypeError::new_err(
@@ -586,7 +586,7 @@ impl PyTensor {
}
}
- if let Ok(tuple) = idx.downcast::<pyo3::types::PyTuple>(py) {
+ if let Ok(tuple) = idx.downcast_bound::<pyo3::types::PyTuple>(py) {
let not_none_count: usize = tuple.iter().filter(|x| !x.is_none()).count();
if not_none_count > dims.len() {
@@ -596,12 +596,12 @@ impl PyTensor {
let mut current_dim = 0;
for item in tuple.iter() {
let (indexer, new_current_dim) =
- extract_indexer(item, current_dim, dims, not_none_count)?;
+ extract_indexer(&item, current_dim, dims, not_none_count)?;
current_dim = new_current_dim;
indexers.push(indexer);
}
} else {
- let (indexer, _) = extract_indexer(idx.downcast::<PyAny>(py)?, 0, dims, 1)?;
+ let (indexer, _) = extract_indexer(idx.downcast_bound::<PyAny>(py)?, 0, dims, 1)?;
indexers.push(indexer);
}
@@ -652,7 +652,7 @@ impl PyTensor {
/// Add two tensors.
/// &RETURNS&: Tensor
- fn __add__(&self, rhs: &PyAny) -> PyResult<Self> {
+ fn __add__(&self, rhs: &Bound<PyAny>) -> PyResult<Self> {
let tensor = if let Ok(rhs) = rhs.extract::<Self>() {
self.0.broadcast_add(&rhs.0).map_err(wrap_err)?
} else if let Ok(rhs) = rhs.extract::<f64>() {
@@ -663,13 +663,13 @@ impl PyTensor {
Ok(Self(tensor))
}
- fn __radd__(&self, rhs: &PyAny) -> PyResult<Self> {
+ fn __radd__(&self, rhs: &Bound<PyAny>) -> PyResult<Self> {
self.__add__(rhs)
}
/// Multiply two tensors.
/// &RETURNS&: Tensor
- fn __mul__(&self, rhs: &PyAny) -> PyResult<Self> {
+ fn __mul__(&self, rhs: &Bound<PyAny>) -> PyResult<Self> {
let tensor = if let Ok(rhs) = rhs.extract::<Self>() {
self.0.broadcast_mul(&rhs.0).map_err(wrap_err)?
} else if let Ok(rhs) = rhs.extract::<f64>() {
@@ -680,13 +680,13 @@ impl PyTensor {
Ok(Self(tensor))
}
- fn __rmul__(&self, rhs: &PyAny) -> PyResult<Self> {
+ fn __rmul__(&self, rhs: &Bound<PyAny>) -> PyResult<Self> {
self.__mul__(rhs)
}
/// Subtract two tensors.
/// &RETURNS&: Tensor
- fn __sub__(&self, rhs: &PyAny) -> PyResult<Self> {
+ fn __sub__(&self, rhs: &Bound<PyAny>) -> PyResult<Self> {
let tensor = if let Ok(rhs) = rhs.extract::<Self>() {
self.0.broadcast_sub(&rhs.0).map_err(wrap_err)?
} else if let Ok(rhs) = rhs.extract::<f64>() {
@@ -699,7 +699,7 @@ impl PyTensor {
/// Divide two tensors.
/// &RETURNS&: Tensor
- fn __truediv__(&self, rhs: &PyAny) -> PyResult<Self> {
+ fn __truediv__(&self, rhs: &Bound<PyAny>) -> PyResult<Self> {
let tensor = if let Ok(rhs) = rhs.extract::<Self>() {
self.0.broadcast_div(&rhs.0).map_err(wrap_err)?
} else if let Ok(rhs) = rhs.extract::<f64>() {
@@ -711,7 +711,7 @@ impl PyTensor {
}
/// Rich-compare two tensors.
/// &RETURNS&: Tensor
- fn __richcmp__(&self, rhs: &PyAny, op: CompareOp) -> PyResult<Self> {
+ fn __richcmp__(&self, rhs: &Bound<PyAny>, op: CompareOp) -> PyResult<Self> {
let compare = |lhs: &Tensor, rhs: &Tensor| {
let t = match op {
CompareOp::Eq => lhs.eq(rhs),
@@ -957,7 +957,7 @@ impl PyTensor {
#[pyo3(signature = (*args, **kwargs), text_signature = "(self, *args, **kwargs)")]
/// Performs Tensor dtype and/or device conversion.
/// &RETURNS&: Tensor
- fn to(&self, args: &PyTuple, kwargs: Option<&PyDict>) -> PyResult<Self> {
+ fn to(&self, args: &Bound<PyTuple>, kwargs: Option<&Bound<PyDict>>) -> PyResult<Self> {
let mut device: Option<PyDevice> = None;
let mut dtype: Option<PyDType> = None;
let mut other: Option<PyTensor> = None;
@@ -1227,7 +1227,7 @@ impl PyQTensor {
///Gets the shape of the tensor.
/// &RETURNS&: Tuple[int]
fn shape(&self, py: Python<'_>) -> PyObject {
- PyTuple::new(py, self.0.shape().dims()).to_object(py)
+ PyTuple::new_bound(py, self.0.shape().dims()).to_object(py)
}
fn __repr__(&self) -> String {
@@ -1265,7 +1265,7 @@ fn load_safetensors(path: &str, py: Python<'_>) -> PyResult<PyObject> {
.into_iter()
.map(|(key, value)| (key, PyTensor(value).into_py(py)))
.collect::<Vec<_>>();
- Ok(res.into_py_dict(py).to_object(py))
+ Ok(res.into_py_dict_bound(py).to_object(py))
}
#[pyfunction]
@@ -1303,7 +1303,7 @@ fn load_ggml(
.map(|(key, qtensor)| Ok((key, PyQTensor(Arc::new(qtensor)).into_py(py))))
.collect::<::candle::Result<Vec<_>>>()
.map_err(wrap_err)?;
- let tensors = tensors.into_py_dict(py).to_object(py);
+ let tensors = tensors.into_py_dict_bound(py).to_object(py);
let hparams = [
("n_vocab", ggml.hparams.n_vocab),
("n_embd", ggml.hparams.n_embd),
@@ -1313,7 +1313,7 @@ fn load_ggml(
("n_rot", ggml.hparams.n_rot),
("ftype", ggml.hparams.ftype),
];
- let hparams = hparams.into_py_dict(py).to_object(py);
+ let hparams = hparams.into_py_dict_bound(py).to_object(py);
let vocab = ggml
.vocab
.token_score_pairs
@@ -1351,7 +1351,7 @@ fn load_gguf(
gguf_file::Value::Bool(x) => x.into_py(py),
gguf_file::Value::String(x) => x.into_py(py),
gguf_file::Value::Array(x) => {
- let list = pyo3::types::PyList::empty(py);
+ let list = pyo3::types::PyList::empty_bound(py);
for elem in x.iter() {
list.append(gguf_value_to_pyobject(elem, py)?)?;
}
@@ -1371,13 +1371,13 @@ fn load_gguf(
})
.collect::<::candle::Result<Vec<_>>>()
.map_err(wrap_err)?;
- let tensors = tensors.into_py_dict(py).to_object(py);
+ let tensors = tensors.into_py_dict_bound(py).to_object(py);
let metadata = gguf
.metadata
.iter()
.map(|(key, value)| Ok((key, gguf_value_to_pyobject(value, py)?)))
.collect::<PyResult<Vec<_>>>()?
- .into_py_dict(py)
+ .into_py_dict_bound(py)
.to_object(py);
Ok((tensors, metadata))
}
@@ -1390,7 +1390,7 @@ fn load_gguf(
fn save_gguf(path: &str, tensors: PyObject, metadata: PyObject, py: Python<'_>) -> PyResult<()> {
use ::candle::quantized::gguf_file;
- fn pyobject_to_gguf_value(v: &PyAny, py: Python<'_>) -> PyResult<gguf_file::Value> {
+ fn pyobject_to_gguf_value(v: &Bound<PyAny>, py: Python<'_>) -> PyResult<gguf_file::Value> {
let v: gguf_file::Value = if let Ok(x) = v.extract::<u8>() {
gguf_file::Value::U8(x)
} else if let Ok(x) = v.extract::<i8>() {
@@ -1418,7 +1418,7 @@ fn save_gguf(path: &str, tensors: PyObject, metadata: PyObject, py: Python<'_>)
} else if let Ok(x) = v.extract::<Vec<PyObject>>() {
let x = x
.into_iter()
- .map(|f| pyobject_to_gguf_value(f.as_ref(py), py))
+ .map(|f| pyobject_to_gguf_value(f.bind(py), py))
.collect::<PyResult<Vec<_>>>()?;
gguf_file::Value::Array(x)
} else {
@@ -1450,7 +1450,7 @@ fn save_gguf(path: &str, tensors: PyObject, metadata: PyObject, py: Python<'_>)
Ok((
key.extract::<String>()
.map_err(|_| PyErr::new::<PyValueError, _>("keys must be strings"))?,
- pyobject_to_gguf_value(value, py)?,
+ pyobject_to_gguf_value(&value.as_borrowed(), py)?,
))
})
.collect::<PyResult<Vec<_>>>()?;
@@ -1498,7 +1498,7 @@ fn get_num_threads() -> usize {
::candle::utils::get_num_threads()
}
-fn candle_utils(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
+fn candle_utils(_py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_function(wrap_pyfunction!(cuda_is_available, m)?)?;
m.add_function(wrap_pyfunction!(get_num_threads, m)?)?;
m.add_function(wrap_pyfunction!(has_accelerate, m)?)?;
@@ -1579,7 +1579,7 @@ fn tanh(tensor: PyTensor) -> PyResult<PyTensor> {
Ok(PyTensor(s))
}
-fn candle_functional_m(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
+fn candle_functional_m(_py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_function(wrap_pyfunction!(silu, m)?)?;
m.add_function(wrap_pyfunction!(softmax, m)?)?;
m.add_function(wrap_pyfunction!(max_pool2d, m)?)?;
@@ -1591,7 +1591,7 @@ fn candle_functional_m(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
}
#[cfg(feature = "onnx")]
-fn candle_onnx_m(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
+fn candle_onnx_m(_py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
use onnx::{PyONNXModel, PyONNXTensorDescriptor};
m.add_class::<PyONNXModel>()?;
m.add_class::<PyONNXTensorDescriptor>()?;
@@ -1599,18 +1599,18 @@ fn candle_onnx_m(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
}
#[pymodule]
-fn candle(py: Python<'_>, m: &PyModule) -> PyResult<()> {
- let utils = PyModule::new(py, "utils")?;
- candle_utils(py, utils)?;
- m.add_submodule(utils)?;
- let nn = PyModule::new(py, "functional")?;
- candle_functional_m(py, nn)?;
- m.add_submodule(nn)?;
+fn candle(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
+ let utils = PyModule::new_bound(py, "utils")?;
+ candle_utils(py, &utils)?;
+ m.add_submodule(&utils)?;
+ let nn = PyModule::new_bound(py, "functional")?;
+ candle_functional_m(py, &nn)?;
+ m.add_submodule(&nn)?;
#[cfg(feature = "onnx")]
{
- let onnx = PyModule::new(py, "onnx")?;
- candle_onnx_m(py, onnx)?;
- m.add_submodule(onnx)?;
+ let onnx = PyModule::new_bound(py, "onnx")?;
+ candle_onnx_m(py, &onnx)?;
+ m.add_submodule(&onnx)?;
}
m.add_class::<PyTensor>()?;
m.add_class::<PyQTensor>()?;
diff --git a/candle-pyo3/src/onnx.rs b/candle-pyo3/src/onnx.rs
index b9a0eb22..a2e9a087 100644
--- a/candle-pyo3/src/onnx.rs
+++ b/candle-pyo3/src/onnx.rs
@@ -39,7 +39,7 @@ impl PyONNXTensorDescriptor {
/// The shape of the tensor.
/// &RETURNS&: Tuple[Union[int,str,Any]]
fn shape(&self, py: Python) -> PyResult<Py<PyTuple>> {
- let shape = PyList::empty(py);
+ let shape = PyList::empty_bound(py);
if let Some(d) = &self.0.shape {
for dim in d.dim.iter() {
if let Some(value) = &dim.value {
diff --git a/candle-pyo3/stub.py b/candle-pyo3/stub.py
index 165941bd..b0e472e6 100644
--- a/candle-pyo3/stub.py
+++ b/candle-pyo3/stub.py
@@ -206,6 +206,8 @@ def write(module, directory, origin, check=False):
if check:
with open(filename, "r") as f:
data = f.read()
+ print("generated content")
+ print(pyi_content)
assert data == pyi_content, f"The content of {filename} seems outdated, please run `python stub.py`"
else:
with open(filename, "w") as f:
@@ -229,6 +231,8 @@ def write(module, directory, origin, check=False):
if check:
with open(filename, "r") as f:
data = f.read()
+ print("generated content")
+ print(py_content)
assert data == py_content, f"The content of {filename} seems outdated, please run `python stub.py`"
else:
with open(filename, "w") as f: