summaryrefslogtreecommitdiff
path: root/candle-pyo3
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2024-10-06 10:09:38 +0200
committerGitHub <noreply@github.com>2024-10-06 10:09:38 +0200
commitf856b5c3a75028d384c26e36501d429091662cd3 (patch)
tree811569350d124b23bb54f5381756ec7cd0b34278 /candle-pyo3
parentd2e432914ec495baff1db29799fe316b9190b0e9 (diff)
downloadcandle-f856b5c3a75028d384c26e36501d429091662cd3.tar.gz
candle-f856b5c3a75028d384c26e36501d429091662cd3.tar.bz2
candle-f856b5c3a75028d384c26e36501d429091662cd3.zip
pyo3 update. (#2545)
* pyo3 update. * Stub fix.
Diffstat (limited to 'candle-pyo3')
-rw-r--r--candle-pyo3/Cargo.toml4
-rw-r--r--candle-pyo3/py_src/candle/utils/__init__.pyi10
-rw-r--r--candle-pyo3/src/lib.rs19
-rw-r--r--candle-pyo3/src/shape.rs12
4 files changed, 20 insertions, 25 deletions
diff --git a/candle-pyo3/Cargo.toml b/candle-pyo3/Cargo.toml
index 88001334..2776a3f7 100644
--- a/candle-pyo3/Cargo.toml
+++ b/candle-pyo3/Cargo.toml
@@ -20,10 +20,10 @@ candle-nn = { workspace = true }
candle-onnx = { workspace = true, optional = true }
half = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
-pyo3 = { version = "0.21.0", features = ["extension-module", "abi3-py38"] }
+pyo3 = { version = "0.22.0", features = ["extension-module", "abi3-py38"] }
[build-dependencies]
-pyo3-build-config = "0.21"
+pyo3-build-config = "0.22"
[features]
default = []
diff --git a/candle-pyo3/py_src/candle/utils/__init__.pyi b/candle-pyo3/py_src/candle/utils/__init__.pyi
index c9a9f9f3..94c32283 100644
--- a/candle-pyo3/py_src/candle/utils/__init__.pyi
+++ b/candle-pyo3/py_src/candle/utils/__init__.pyi
@@ -33,9 +33,7 @@ def has_mkl() -> bool:
pass
@staticmethod
-def load_ggml(
- path: Union[str, PathLike], device: Optional[Device] = None
-) -> Tuple[Dict[str, QTensor], Dict[str, Any], List[str]]:
+def load_ggml(path, device=None) -> Tuple[Dict[str, QTensor], Dict[str, Any], List[str]]:
"""
Load a GGML file. Returns a tuple of three objects: a dictionary mapping tensor names to tensors,
a dictionary mapping hyperparameter names to hyperparameter values, and a vocabulary.
@@ -43,9 +41,7 @@ def load_ggml(
pass
@staticmethod
-def load_gguf(
- path: Union[str, PathLike], device: Optional[Device] = None
-) -> Tuple[Dict[str, QTensor], Dict[str, Any]]:
+def load_gguf(path, device=None) -> Tuple[Dict[str, QTensor], Dict[str, Any]]:
"""
Loads a GGUF file. Returns a tuple of two dictionaries: the first maps tensor names to tensors,
and the second maps metadata keys to metadata values.
@@ -60,7 +56,7 @@ def load_safetensors(path: Union[str, PathLike]) -> Dict[str, Tensor]:
pass
@staticmethod
-def save_gguf(path: Union[str, PathLike], tensors: Dict[str, QTensor], metadata: Dict[str, Any]):
+def save_gguf(path, tensors, metadata):
"""
Save quanitzed tensors and metadata to a GGUF file.
"""
diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs
index 0da2c700..722b5e3a 100644
--- a/candle-pyo3/src/lib.rs
+++ b/candle-pyo3/src/lib.rs
@@ -6,7 +6,6 @@ use pyo3::types::{IntoPyDict, PyDict, PyTuple};
use pyo3::ToPyObject;
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
-use std::os::raw::c_long;
use std::sync::Arc;
use half::{bf16, f16};
@@ -115,7 +114,7 @@ impl PyDevice {
}
impl<'source> FromPyObject<'source> for PyDevice {
- fn extract(ob: &'source PyAny) -> PyResult<Self> {
+ fn extract_bound(ob: &Bound<'source, PyAny>) -> PyResult<Self> {
let device: String = ob.extract()?;
let device = match device.as_str() {
"cpu" => PyDevice::Cpu,
@@ -217,11 +216,11 @@ enum Indexer {
IndexSelect(Tensor),
}
-#[derive(Clone, Debug)]
+#[derive(Debug)]
struct TorchTensor(PyObject);
impl<'source> pyo3::FromPyObject<'source> for TorchTensor {
- fn extract(ob: &'source PyAny) -> PyResult<Self> {
+ fn extract_bound(ob: &Bound<'source, PyAny>) -> PyResult<Self> {
let numpy_value: PyObject = ob.getattr("numpy")?.call0()?.extract()?;
Ok(TorchTensor(numpy_value))
}
@@ -540,7 +539,7 @@ impl PyTensor {
))
} else if let Ok(slice) = py_indexer.downcast::<pyo3::types::PySlice>() {
// Handle a single slice e.g. tensor[0:1] or tensor[0:-1]
- let index = slice.indices(dims[current_dim] as c_long)?;
+ let index = slice.indices(dims[current_dim] as isize)?;
Ok((
Indexer::Slice(index.start as usize, index.stop as usize),
current_dim + 1,
@@ -1284,7 +1283,7 @@ fn save_safetensors(
}
#[pyfunction]
-#[pyo3(text_signature = "(path:Union[str,PathLike], device: Optional[Device] = None)")]
+#[pyo3(signature = (path, device = None))]
/// Load a GGML file. Returns a tuple of three objects: a dictionary mapping tensor names to tensors,
/// a dictionary mapping hyperparameter names to hyperparameter values, and a vocabulary.
/// &RETURNS&: Tuple[Dict[str,QTensor], Dict[str,Any], List[str]]
@@ -1325,7 +1324,7 @@ fn load_ggml(
}
#[pyfunction]
-#[pyo3(text_signature = "(path:Union[str,PathLike], device: Optional[Device] = None)")]
+#[pyo3(signature = (path, device = None))]
/// Loads a GGUF file. Returns a tuple of two dictionaries: the first maps tensor names to tensors,
/// and the second maps metadata keys to metadata values.
/// &RETURNS&: Tuple[Dict[str,QTensor], Dict[str,Any]]
@@ -1384,7 +1383,7 @@ fn load_gguf(
#[pyfunction]
#[pyo3(
- text_signature = "(path:Union[str,PathLike], tensors:Dict[str,QTensor], metadata:Dict[str,Any])"
+ signature = (path, tensors, metadata)
)]
/// Save quanitzed tensors and metadata to a GGUF file.
fn save_gguf(path: &str, tensors: PyObject, metadata: PyObject, py: Python<'_>) -> PyResult<()> {
@@ -1430,7 +1429,7 @@ fn save_gguf(path: &str, tensors: PyObject, metadata: PyObject, py: Python<'_>)
Ok(v)
}
let tensors = tensors
- .extract::<&PyDict>(py)
+ .downcast_bound::<PyDict>(py)
.map_err(|_| PyErr::new::<PyValueError, _>("expected a dict"))?
.iter()
.map(|(key, value)| {
@@ -1443,7 +1442,7 @@ fn save_gguf(path: &str, tensors: PyObject, metadata: PyObject, py: Python<'_>)
.collect::<PyResult<Vec<_>>>()?;
let metadata = metadata
- .extract::<&PyDict>(py)
+ .downcast_bound::<PyDict>(py)
.map_err(|_| PyErr::new::<PyValueError, _>("expected a dict"))?
.iter()
.map(|(key, value)| {
diff --git a/candle-pyo3/src/shape.rs b/candle-pyo3/src/shape.rs
index 2668b733..b9bc6789 100644
--- a/candle-pyo3/src/shape.rs
+++ b/candle-pyo3/src/shape.rs
@@ -6,7 +6,7 @@ use pyo3::prelude::*;
pub struct PyShape(Vec<usize>);
impl<'source> pyo3::FromPyObject<'source> for PyShape {
- fn extract(ob: &'source PyAny) -> PyResult<Self> {
+ fn extract_bound(ob: &Bound<'source, PyAny>) -> PyResult<Self> {
if ob.is_none() {
return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
"Shape cannot be None",
@@ -16,10 +16,10 @@ impl<'source> pyo3::FromPyObject<'source> for PyShape {
let tuple = ob.downcast::<pyo3::types::PyTuple>()?;
if tuple.len() == 1 {
let first_element = tuple.get_item(0)?;
- let dims: Vec<usize> = pyo3::FromPyObject::extract(first_element)?;
+ let dims: Vec<usize> = pyo3::FromPyObject::extract_bound(&first_element)?;
Ok(PyShape(dims))
} else {
- let dims: Vec<usize> = pyo3::FromPyObject::extract(tuple)?;
+ let dims: Vec<usize> = pyo3::FromPyObject::extract_bound(tuple)?;
Ok(PyShape(dims))
}
}
@@ -36,7 +36,7 @@ impl From<PyShape> for ::candle::Shape {
pub struct PyShapeWithHole(Vec<isize>);
impl<'source> pyo3::FromPyObject<'source> for PyShapeWithHole {
- fn extract(ob: &'source PyAny) -> PyResult<Self> {
+ fn extract_bound(ob: &Bound<'source, PyAny>) -> PyResult<Self> {
if ob.is_none() {
return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
"Shape cannot be None",
@@ -46,9 +46,9 @@ impl<'source> pyo3::FromPyObject<'source> for PyShapeWithHole {
let tuple = ob.downcast::<pyo3::types::PyTuple>()?;
let dims: Vec<isize> = if tuple.len() == 1 {
let first_element = tuple.get_item(0)?;
- pyo3::FromPyObject::extract(first_element)?
+ pyo3::FromPyObject::extract_bound(&first_element)?
} else {
- pyo3::FromPyObject::extract(tuple)?
+ pyo3::FromPyObject::extract_bound(tuple)?
};
// Ensure we have only positive numbers and at most one "hole" (-1)