diff options
author | Lukas Kreussel <65088241+LLukas22@users.noreply.github.com> | 2023-10-17 12:07:26 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-10-17 11:07:26 +0100 |
commit | f9e93f5b6909b4f680c244a0d049add181675958 (patch) | |
tree | e509752e90521d6500eb22e35e56b6322a9b6706 /candle-pyo3 | |
parent | b355ab4e2e52b077e71aac46c286fbce033f36d6 (diff) | |
download | candle-f9e93f5b6909b4f680c244a0d049add181675958.tar.gz candle-f9e93f5b6909b4f680c244a0d049add181675958.tar.bz2 candle-f9e93f5b6909b4f680c244a0d049add181675958.zip |
Extend `stub.py` to accept external typehinting (#1102)
Diffstat (limited to 'candle-pyo3')
-rw-r--r-- | candle-pyo3/_additional_typing/README.md | 3 | ||||
-rw-r--r-- | candle-pyo3/_additional_typing/__init__.py | 55 | ||||
-rw-r--r-- | candle-pyo3/py_src/candle/__init__.pyi | 42 | ||||
-rw-r--r-- | candle-pyo3/py_src/candle/functional/__init__.pyi | 2 | ||||
-rw-r--r-- | candle-pyo3/py_src/candle/typing/__init__.py | 4 | ||||
-rw-r--r-- | candle-pyo3/py_src/candle/utils/__init__.pyi | 2 | ||||
-rw-r--r-- | candle-pyo3/stub.py | 42 |
7 files changed, 146 insertions, 4 deletions
diff --git a/candle-pyo3/_additional_typing/README.md b/candle-pyo3/_additional_typing/README.md new file mode 100644 index 00000000..ab5074e0 --- /dev/null +++ b/candle-pyo3/_additional_typing/README.md @@ -0,0 +1,3 @@ +This python module contains external typehinting for certain `candle` classes. This is only necessary for `magic` methodes e.g. `__add__` as their text signature cant be set via pyo3. + +The classes in this module will be parsed by the `stub.py` script and interleafed with the signatures of the actual pyo3 `candle.candle` module.
\ No newline at end of file diff --git a/candle-pyo3/_additional_typing/__init__.py b/candle-pyo3/_additional_typing/__init__.py new file mode 100644 index 00000000..0d0eec90 --- /dev/null +++ b/candle-pyo3/_additional_typing/__init__.py @@ -0,0 +1,55 @@ +from typing import Union, Sequence + + +class Tensor: + """ + This contains the type hints for the magic methodes of the `candle.Tensor` class. + """ + + def __add__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor": + """ + Add a scalar to a tensor or two tensors together. + """ + pass + + def __radd__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor": + """ + Add a scalar to a tensor or two tensors together. + """ + pass + + def __sub__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor": + """ + Subtract a scalar from a tensor or one tensor from another. + """ + pass + + def __truediv__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor": + """ + Divide a tensor by a scalar or one tensor by another. + """ + pass + + def __mul__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor": + """ + Multiply a tensor by a scalar or one tensor by another. + """ + pass + + def __rmul__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor": + """ + Multiply a tensor by a scalar or one tensor by another. + """ + pass + + def __richcmp__(self, rhs: Union["Tensor", "Scalar"], op) -> "Tensor": + """ + Compare a tensor with a scalar or one tensor with another. + """ + pass + + def __getitem__(self, index: Union["Index", "Tensor", Sequence["Index"]]) -> "Tensor": + """ + Return a slice of a tensor. + """ + pass diff --git a/candle-pyo3/py_src/candle/__init__.pyi b/candle-pyo3/py_src/candle/__init__.pyi index 414f0bc4..4096907b 100644 --- a/candle-pyo3/py_src/candle/__init__.pyi +++ b/candle-pyo3/py_src/candle/__init__.pyi @@ -1,7 +1,7 @@ # Generated content DO NOT EDIT from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence from os import PathLike -from candle.typing import _ArrayLike, Device +from candle.typing import _ArrayLike, Device, Scalar, Index class bf16(DType): pass @@ -119,6 +119,46 @@ class Tensor: def __init__(self, data: _ArrayLike): pass + def __add__(self, rhs: Union[Tensor, Scalar]) -> "Tensor": + """ + Add a scalar to a tensor or two tensors together. + """ + pass + def __getitem__(self, index: Union[Index, Tensor, Sequence[Index]]) -> "Tensor": + """ + Return a slice of a tensor. + """ + pass + def __mul__(self, rhs: Union[Tensor, Scalar]) -> "Tensor": + """ + Multiply a tensor by a scalar or one tensor by another. + """ + pass + def __radd__(self, rhs: Union[Tensor, Scalar]) -> "Tensor": + """ + Add a scalar to a tensor or two tensors together. + """ + pass + def __richcmp__(self, rhs: Union[Tensor, Scalar], op) -> "Tensor": + """ + Compare a tensor with a scalar or one tensor with another. + """ + pass + def __rmul__(self, rhs: Union[Tensor, Scalar]) -> "Tensor": + """ + Multiply a tensor by a scalar or one tensor by another. + """ + pass + def __sub__(self, rhs: Union[Tensor, Scalar]) -> "Tensor": + """ + Subtract a scalar from a tensor or one tensor from another. + """ + pass + def __truediv__(self, rhs: Union[Tensor, Scalar]) -> "Tensor": + """ + Divide a tensor by a scalar or one tensor by another. + """ + pass def argmax_keepdim(self, dim: int) -> Tensor: """ Returns the indices of the maximum value(s) across the selected dimension. diff --git a/candle-pyo3/py_src/candle/functional/__init__.pyi b/candle-pyo3/py_src/candle/functional/__init__.pyi index 6f206e40..5bf5c4c3 100644 --- a/candle-pyo3/py_src/candle/functional/__init__.pyi +++ b/candle-pyo3/py_src/candle/functional/__init__.pyi @@ -1,7 +1,7 @@ # Generated content DO NOT EDIT from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence from os import PathLike -from candle.typing import _ArrayLike, Device +from candle.typing import _ArrayLike, Device, Scalar, Index from candle import Tensor, DType, QTensor @staticmethod diff --git a/candle-pyo3/py_src/candle/typing/__init__.py b/candle-pyo3/py_src/candle/typing/__init__.py index ccdb6238..66bc3d8a 100644 --- a/candle-pyo3/py_src/candle/typing/__init__.py +++ b/candle-pyo3/py_src/candle/typing/__init__.py @@ -14,3 +14,7 @@ CPU: str = "cpu" CUDA: str = "cuda" Device = TypeVar("Device", CPU, CUDA) + +Scalar = Union[int, float] + +Index = Union[int, slice, None, "Ellipsis"] diff --git a/candle-pyo3/py_src/candle/utils/__init__.pyi b/candle-pyo3/py_src/candle/utils/__init__.pyi index 61964ffc..d3b93766 100644 --- a/candle-pyo3/py_src/candle/utils/__init__.pyi +++ b/candle-pyo3/py_src/candle/utils/__init__.pyi @@ -1,7 +1,7 @@ # Generated content DO NOT EDIT from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence from os import PathLike -from candle.typing import _ArrayLike, Device +from candle.typing import _ArrayLike, Device, Scalar, Index from candle import Tensor, DType, QTensor @staticmethod diff --git a/candle-pyo3/stub.py b/candle-pyo3/stub.py index 3100a10c..8e4318bc 100644 --- a/candle-pyo3/stub.py +++ b/candle-pyo3/stub.py @@ -5,6 +5,7 @@ import os from typing import Optional import black from pathlib import Path +import re INDENT = " " * 4 @@ -12,9 +13,11 @@ GENERATED_COMMENT = "# Generated content DO NOT EDIT\n" TYPING = """from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence from os import PathLike """ -CANDLE_SPECIFIC_TYPING = "from candle.typing import _ArrayLike, Device\n" +CANDLE_SPECIFIC_TYPING = "from candle.typing import _ArrayLike, Device, Scalar, Index\n" CANDLE_TENSOR_IMPORTS = "from candle import Tensor,DType,QTensor\n" RETURN_TYPE_MARKER = "&RETURNS&: " +ADDITIONAL_TYPEHINTS = {} +FORWARD_REF_PATTERN = re.compile(r"ForwardRef\('([^']+)'\)") def do_indent(text: Optional[str], indent: str): @@ -115,6 +118,27 @@ def pyi_file(obj, indent=""): body += f"{indent+INDENT}pass\n" body += "\n" + if obj.__name__ in ADDITIONAL_TYPEHINTS: + additional_members = inspect.getmembers(ADDITIONAL_TYPEHINTS[obj.__name__]) + additional_functions = [] + for name, member in additional_members: + if inspect.isfunction(member): + additional_functions.append((name, member)) + + def process_additional_function(fn): + signature = inspect.signature(fn) + cleaned_signature = re.sub(FORWARD_REF_PATTERN, r"\1", str(signature)) + string = f"{indent}def {fn.__name__}{cleaned_signature}:\n" + string += ( + f'{indent+INDENT}"""{indent+INDENT}{do_indent(fn.__doc__, indent+INDENT)}{indent+INDENT}"""\n' + ) + string += f"{indent+INDENT}pass\n" + string += "\n" + return string + + for name, fn in additional_functions: + body += process_additional_function(fn) + for name, fn in fns: body += pyi_file(fn, indent=indent) @@ -215,6 +239,19 @@ def write(module, directory, origin, check=False): write(submodule, os.path.join(directory, name), f"{name}", check=check) +def extract_additional_types(module): + additional_types = {} + for name, member in inspect.getmembers(module): + if inspect.isclass(member): + if hasattr(member, "__name__"): + name = member.__name__ + else: + name = str(member) + if name not in additional_types: + additional_types[name] = member + return additional_types + + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--check", action="store_true") @@ -228,5 +265,8 @@ if __name__ == "__main__": directory = f"candle-pyo3/{directory}" import candle + import _additional_typing + + ADDITIONAL_TYPEHINTS = extract_additional_types(_additional_typing) write(candle.candle, directory, "candle", check=args.check) |