diff options
author | Lukas Kreussel <65088241+LLukas22@users.noreply.github.com> | 2023-11-08 06:37:50 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-11-08 06:37:50 +0100 |
commit | f3a4f3db768d46defc16de48208107db1b32159d (patch) | |
tree | 21ae0872e46621656559ec0caf6d7625e6ca7e76 /candle-pyo3/py_src/candle/onnx | |
parent | 7920b45c8ac737b67e23f04297f6bd7e4860f373 (diff) | |
download | candle-f3a4f3db768d46defc16de48208107db1b32159d.tar.gz candle-f3a4f3db768d46defc16de48208107db1b32159d.tar.bz2 candle-f3a4f3db768d46defc16de48208107db1b32159d.zip |
PyO3: Add optional `candle.onnx` module (#1282)
* Start onnx integration
* Merge remote-tracking branch 'upstream/main' into feat/pyo3-onnx
* Implement ONNXModel
* `fmt`
* add `onnx` flag to python ci
* Pin `protoc` to `25.0`
* Setup `protoc` in wheel builds
* Build wheels with `onnx`
* Install `protoc` in manylinux containers
* `apt` -> `yum`
* Download `protoc` via bash script
* Back to `manylinux: auto`
* Disable `onnx` builds for linux
Diffstat (limited to 'candle-pyo3/py_src/candle/onnx')
-rw-r--r-- | candle-pyo3/py_src/candle/onnx/__init__.py | 5 | ||||
-rw-r--r-- | candle-pyo3/py_src/candle/onnx/__init__.pyi | 89 |
2 files changed, 94 insertions, 0 deletions
diff --git a/candle-pyo3/py_src/candle/onnx/__init__.py b/candle-pyo3/py_src/candle/onnx/__init__.py new file mode 100644 index 00000000..856ecd7d --- /dev/null +++ b/candle-pyo3/py_src/candle/onnx/__init__.py @@ -0,0 +1,5 @@ +# Generated content DO NOT EDIT +from .. import onnx + +ONNXModel = onnx.ONNXModel +ONNXTensorDescription = onnx.ONNXTensorDescription diff --git a/candle-pyo3/py_src/candle/onnx/__init__.pyi b/candle-pyo3/py_src/candle/onnx/__init__.pyi new file mode 100644 index 00000000..8ce1b3aa --- /dev/null +++ b/candle-pyo3/py_src/candle/onnx/__init__.pyi @@ -0,0 +1,89 @@ +# Generated content DO NOT EDIT +from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence +from os import PathLike +from candle.typing import _ArrayLike, Device, Scalar, Index, Shape +from candle import Tensor, DType, QTensor + +class ONNXModel: + """ + A wrapper around an ONNX model. + """ + + def __init__(self, path: str): + pass + @property + def doc_string(self) -> str: + """ + The doc string of the model. + """ + pass + @property + def domain(self) -> str: + """ + The domain of the operator set of the model. + """ + pass + def initializers(self) -> Dict[str, Tensor]: + """ + Get the weights of the model. + """ + pass + @property + def inputs(self) -> Optional[Dict[str, ONNXTensorDescription]]: + """ + The inputs of the model. + """ + pass + @property + def ir_version(self) -> int: + """ + The version of the IR this model targets. + """ + pass + @property + def model_version(self) -> int: + """ + The version of the model. + """ + pass + @property + def outputs(self) -> Optional[Dict[str, ONNXTensorDescription]]: + """ + The outputs of the model. + """ + pass + @property + def producer_name(self) -> str: + """ + The producer of the model. + """ + pass + @property + def producer_version(self) -> str: + """ + The version of the producer of the model. + """ + pass + def run(self, inputs: Dict[str, Tensor]) -> Dict[str, Tensor]: + """ + Run the model on the given inputs. + """ + pass + +class ONNXTensorDescription: + """ + A wrapper around an ONNX tensor description. + """ + + @property + def dtype(self) -> DType: + """ + The data type of the tensor. + """ + pass + @property + def shape(self) -> Tuple[Union[int, str, Any]]: + """ + The shape of the tensor. + """ + pass |