summaryrefslogtreecommitdiff
path: root/candle-pyo3/py_src/candle/onnx/__init__.pyi
blob: 8ce1b3aacaebbb1a70d1b0ae6a50aadea99b1dae (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
# Generated content DO NOT EDIT
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence
from os import PathLike
from candle.typing import _ArrayLike, Device, Scalar, Index, Shape
from candle import Tensor, DType, QTensor

class ONNXModel:
    """
    Python-side handle to a loaded ONNX model.

    Exposes the model's metadata (as read-only properties), its weights,
    its input/output descriptions, and a method to execute it.
    """

    def __init__(self, path: str):
        # NOTE(review): `path` is presumably the location of the ONNX file
        # to open — confirm against the native binding.
        ...

    @property
    def doc_string(self) -> str:
        """
        Documentation string embedded in the model.
        """
        ...

    @property
    def domain(self) -> str:
        """
        Domain of the operator set the model uses.
        """
        ...

    def initializers(self) -> Dict[str, Tensor]:
        """
        Return the model's weights as a mapping from string key to Tensor.
        """
        ...

    @property
    def inputs(self) -> Optional[Dict[str, ONNXTensorDescription]]:
        """
        Descriptions of the model's inputs, or None.
        """
        ...

    @property
    def ir_version(self) -> int:
        """
        Version of the IR targeted by the model.
        """
        ...

    @property
    def model_version(self) -> int:
        """
        Version number of the model itself.
        """
        ...

    @property
    def outputs(self) -> Optional[Dict[str, ONNXTensorDescription]]:
        """
        Descriptions of the model's outputs, or None.
        """
        ...

    @property
    def producer_name(self) -> str:
        """
        Name of the model's producer.
        """
        ...

    @property
    def producer_version(self) -> str:
        """
        Version of the model's producer.
        """
        ...

    def run(self, inputs: Dict[str, Tensor]) -> Dict[str, Tensor]:
        """
        Execute the model on `inputs` and return the resulting tensors.
        """
        ...

class ONNXTensorDescription:
    """
    A wrapper around an ONNX tensor description.

    Exposes the tensor's data type and shape. Values of this type appear in
    `ONNXModel.inputs` and `ONNXModel.outputs`.
    """

    @property
    def dtype(self) -> DType:
        """
        The data type of the tensor.
        """
        pass
    # NOTE(review): this stub is generated (see the file header). The
    # variable-length `Tuple[..., ...]` correction below must also be made in
    # the definition this file is generated from, or regeneration reverts it.
    @property
    def shape(self) -> Tuple[Union[int, str, Any], ...]:
        """
        The shape of the tensor, as a variable-length tuple.

        Each entry is an `int`, a `str`, or a value typed `Any`. The `...`
        marks the tuple as arbitrary-length; the bare `Tuple[X]` form would
        claim exactly one element.
        NOTE(review): `str` entries are presumably symbolic (named)
        dimensions — confirm against the native binding.
        """
        pass