summaryrefslogtreecommitdiff
path: root/candle-pyo3
diff options
context:
space:
mode:
authorLukas Kreussel <65088241+LLukas22@users.noreply.github.com>2023-10-17 12:07:26 +0200
committerGitHub <noreply@github.com>2023-10-17 11:07:26 +0100
commitf9e93f5b6909b4f680c244a0d049add181675958 (patch)
treee509752e90521d6500eb22e35e56b6322a9b6706 /candle-pyo3
parentb355ab4e2e52b077e71aac46c286fbce033f36d6 (diff)
downloadcandle-f9e93f5b6909b4f680c244a0d049add181675958.tar.gz
candle-f9e93f5b6909b4f680c244a0d049add181675958.tar.bz2
candle-f9e93f5b6909b4f680c244a0d049add181675958.zip
Extend `stub.py` to accept external typehinting (#1102)
Diffstat (limited to 'candle-pyo3')
-rw-r--r--candle-pyo3/_additional_typing/README.md3
-rw-r--r--candle-pyo3/_additional_typing/__init__.py55
-rw-r--r--candle-pyo3/py_src/candle/__init__.pyi42
-rw-r--r--candle-pyo3/py_src/candle/functional/__init__.pyi2
-rw-r--r--candle-pyo3/py_src/candle/typing/__init__.py4
-rw-r--r--candle-pyo3/py_src/candle/utils/__init__.pyi2
-rw-r--r--candle-pyo3/stub.py42
7 files changed, 146 insertions, 4 deletions
diff --git a/candle-pyo3/_additional_typing/README.md b/candle-pyo3/_additional_typing/README.md
new file mode 100644
index 00000000..ab5074e0
--- /dev/null
+++ b/candle-pyo3/_additional_typing/README.md
@@ -0,0 +1,3 @@
+This python module contains external typehinting for certain `candle` classes. This is only necessary for `magic` methodes e.g. `__add__` as their text signature cant be set via pyo3.
+
+The classes in this module will be parsed by the `stub.py` script and interleafed with the signatures of the actual pyo3 `candle.candle` module. \ No newline at end of file
diff --git a/candle-pyo3/_additional_typing/__init__.py b/candle-pyo3/_additional_typing/__init__.py
new file mode 100644
index 00000000..0d0eec90
--- /dev/null
+++ b/candle-pyo3/_additional_typing/__init__.py
@@ -0,0 +1,55 @@
+from typing import Union, Sequence
+
+
+class Tensor:
+ """
+ This contains the type hints for the magic methodes of the `candle.Tensor` class.
+ """
+
+ def __add__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor":
+ """
+ Add a scalar to a tensor or two tensors together.
+ """
+ pass
+
+ def __radd__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor":
+ """
+ Add a scalar to a tensor or two tensors together.
+ """
+ pass
+
+ def __sub__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor":
+ """
+ Subtract a scalar from a tensor or one tensor from another.
+ """
+ pass
+
+ def __truediv__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor":
+ """
+ Divide a tensor by a scalar or one tensor by another.
+ """
+ pass
+
+ def __mul__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor":
+ """
+ Multiply a tensor by a scalar or one tensor by another.
+ """
+ pass
+
+ def __rmul__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor":
+ """
+ Multiply a tensor by a scalar or one tensor by another.
+ """
+ pass
+
+ def __richcmp__(self, rhs: Union["Tensor", "Scalar"], op) -> "Tensor":
+ """
+ Compare a tensor with a scalar or one tensor with another.
+ """
+ pass
+
+ def __getitem__(self, index: Union["Index", "Tensor", Sequence["Index"]]) -> "Tensor":
+ """
+ Return a slice of a tensor.
+ """
+ pass
diff --git a/candle-pyo3/py_src/candle/__init__.pyi b/candle-pyo3/py_src/candle/__init__.pyi
index 414f0bc4..4096907b 100644
--- a/candle-pyo3/py_src/candle/__init__.pyi
+++ b/candle-pyo3/py_src/candle/__init__.pyi
@@ -1,7 +1,7 @@
# Generated content DO NOT EDIT
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence
from os import PathLike
-from candle.typing import _ArrayLike, Device
+from candle.typing import _ArrayLike, Device, Scalar, Index
class bf16(DType):
pass
@@ -119,6 +119,46 @@ class Tensor:
def __init__(self, data: _ArrayLike):
pass
+ def __add__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
+ """
+ Add a scalar to a tensor or two tensors together.
+ """
+ pass
+ def __getitem__(self, index: Union[Index, Tensor, Sequence[Index]]) -> "Tensor":
+ """
+ Return a slice of a tensor.
+ """
+ pass
+ def __mul__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
+ """
+ Multiply a tensor by a scalar or one tensor by another.
+ """
+ pass
+ def __radd__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
+ """
+ Add a scalar to a tensor or two tensors together.
+ """
+ pass
+ def __richcmp__(self, rhs: Union[Tensor, Scalar], op) -> "Tensor":
+ """
+ Compare a tensor with a scalar or one tensor with another.
+ """
+ pass
+ def __rmul__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
+ """
+ Multiply a tensor by a scalar or one tensor by another.
+ """
+ pass
+ def __sub__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
+ """
+ Subtract a scalar from a tensor or one tensor from another.
+ """
+ pass
+ def __truediv__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
+ """
+ Divide a tensor by a scalar or one tensor by another.
+ """
+ pass
def argmax_keepdim(self, dim: int) -> Tensor:
"""
Returns the indices of the maximum value(s) across the selected dimension.
diff --git a/candle-pyo3/py_src/candle/functional/__init__.pyi b/candle-pyo3/py_src/candle/functional/__init__.pyi
index 6f206e40..5bf5c4c3 100644
--- a/candle-pyo3/py_src/candle/functional/__init__.pyi
+++ b/candle-pyo3/py_src/candle/functional/__init__.pyi
@@ -1,7 +1,7 @@
# Generated content DO NOT EDIT
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence
from os import PathLike
-from candle.typing import _ArrayLike, Device
+from candle.typing import _ArrayLike, Device, Scalar, Index
from candle import Tensor, DType, QTensor
@staticmethod
diff --git a/candle-pyo3/py_src/candle/typing/__init__.py b/candle-pyo3/py_src/candle/typing/__init__.py
index ccdb6238..66bc3d8a 100644
--- a/candle-pyo3/py_src/candle/typing/__init__.py
+++ b/candle-pyo3/py_src/candle/typing/__init__.py
@@ -14,3 +14,7 @@ CPU: str = "cpu"
CUDA: str = "cuda"
Device = TypeVar("Device", CPU, CUDA)
+
+Scalar = Union[int, float]
+
+Index = Union[int, slice, None, "Ellipsis"]
diff --git a/candle-pyo3/py_src/candle/utils/__init__.pyi b/candle-pyo3/py_src/candle/utils/__init__.pyi
index 61964ffc..d3b93766 100644
--- a/candle-pyo3/py_src/candle/utils/__init__.pyi
+++ b/candle-pyo3/py_src/candle/utils/__init__.pyi
@@ -1,7 +1,7 @@
# Generated content DO NOT EDIT
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence
from os import PathLike
-from candle.typing import _ArrayLike, Device
+from candle.typing import _ArrayLike, Device, Scalar, Index
from candle import Tensor, DType, QTensor
@staticmethod
diff --git a/candle-pyo3/stub.py b/candle-pyo3/stub.py
index 3100a10c..8e4318bc 100644
--- a/candle-pyo3/stub.py
+++ b/candle-pyo3/stub.py
@@ -5,6 +5,7 @@ import os
from typing import Optional
import black
from pathlib import Path
+import re
INDENT = " " * 4
@@ -12,9 +13,11 @@ GENERATED_COMMENT = "# Generated content DO NOT EDIT\n"
TYPING = """from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence
from os import PathLike
"""
-CANDLE_SPECIFIC_TYPING = "from candle.typing import _ArrayLike, Device\n"
+CANDLE_SPECIFIC_TYPING = "from candle.typing import _ArrayLike, Device, Scalar, Index\n"
CANDLE_TENSOR_IMPORTS = "from candle import Tensor,DType,QTensor\n"
RETURN_TYPE_MARKER = "&RETURNS&: "
+ADDITIONAL_TYPEHINTS = {}
+FORWARD_REF_PATTERN = re.compile(r"ForwardRef\('([^']+)'\)")
def do_indent(text: Optional[str], indent: str):
@@ -115,6 +118,27 @@ def pyi_file(obj, indent=""):
body += f"{indent+INDENT}pass\n"
body += "\n"
+ if obj.__name__ in ADDITIONAL_TYPEHINTS:
+ additional_members = inspect.getmembers(ADDITIONAL_TYPEHINTS[obj.__name__])
+ additional_functions = []
+ for name, member in additional_members:
+ if inspect.isfunction(member):
+ additional_functions.append((name, member))
+
+ def process_additional_function(fn):
+ signature = inspect.signature(fn)
+ cleaned_signature = re.sub(FORWARD_REF_PATTERN, r"\1", str(signature))
+ string = f"{indent}def {fn.__name__}{cleaned_signature}:\n"
+ string += (
+ f'{indent+INDENT}"""{indent+INDENT}{do_indent(fn.__doc__, indent+INDENT)}{indent+INDENT}"""\n'
+ )
+ string += f"{indent+INDENT}pass\n"
+ string += "\n"
+ return string
+
+ for name, fn in additional_functions:
+ body += process_additional_function(fn)
+
for name, fn in fns:
body += pyi_file(fn, indent=indent)
@@ -215,6 +239,19 @@ def write(module, directory, origin, check=False):
write(submodule, os.path.join(directory, name), f"{name}", check=check)
+def extract_additional_types(module):
+ additional_types = {}
+ for name, member in inspect.getmembers(module):
+ if inspect.isclass(member):
+ if hasattr(member, "__name__"):
+ name = member.__name__
+ else:
+ name = str(member)
+ if name not in additional_types:
+ additional_types[name] = member
+ return additional_types
+
+
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--check", action="store_true")
@@ -228,5 +265,8 @@ if __name__ == "__main__":
directory = f"candle-pyo3/{directory}"
import candle
+ import _additional_typing
+
+ ADDITIONAL_TYPEHINTS = extract_additional_types(_additional_typing)
write(candle.candle, directory, "candle", check=args.check)