From f9e93f5b6909b4f680c244a0d049add181675958 Mon Sep 17 00:00:00 2001
From: Lukas Kreussel <65088241+LLukas22@users.noreply.github.com>
Date: Tue, 17 Oct 2023 12:07:26 +0200
Subject: Extend `stub.py` to accept external typehinting (#1102)

---
 candle-pyo3/_additional_typing/README.md          |  3 ++
 candle-pyo3/_additional_typing/__init__.py        | 55 +++++++++++++++++++++++
 candle-pyo3/py_src/candle/__init__.pyi            | 42 ++++++++++++++++-
 candle-pyo3/py_src/candle/functional/__init__.pyi |  2 +-
 candle-pyo3/py_src/candle/typing/__init__.py      |  4 ++
 candle-pyo3/py_src/candle/utils/__init__.pyi      |  2 +-
 candle-pyo3/stub.py                               | 42 ++++++++++++++++-
 7 files changed, 146 insertions(+), 4 deletions(-)
 create mode 100644 candle-pyo3/_additional_typing/README.md
 create mode 100644 candle-pyo3/_additional_typing/__init__.py

(limited to 'candle-pyo3')

diff --git a/candle-pyo3/_additional_typing/README.md b/candle-pyo3/_additional_typing/README.md
new file mode 100644
index 00000000..ab5074e0
--- /dev/null
+++ b/candle-pyo3/_additional_typing/README.md
@@ -0,0 +1,3 @@
+This python module contains external typehinting for certain `candle` classes. This is only necessary for `magic` methodes e.g. `__add__` as their text signature cant be set via pyo3.
+
+The classes in this module will be parsed by the `stub.py` script and interleafed with the signatures of the actual pyo3 `candle.candle` module.
\ No newline at end of file
diff --git a/candle-pyo3/_additional_typing/__init__.py b/candle-pyo3/_additional_typing/__init__.py
new file mode 100644
index 00000000..0d0eec90
--- /dev/null
+++ b/candle-pyo3/_additional_typing/__init__.py
@@ -0,0 +1,55 @@
+from typing import Union, Sequence
+
+
+class Tensor:
+    """
+    This contains the type hints for the magic methodes of the `candle.Tensor` class.
+    """
+
+    def __add__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor":
+        """
+        Add a scalar to a tensor or two tensors together.
+        """
+        pass
+
+    def __radd__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor":
+        """
+        Add a scalar to a tensor or two tensors together.
+        """
+        pass
+
+    def __sub__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor":
+        """
+        Subtract a scalar from a tensor or one tensor from another.
+        """
+        pass
+
+    def __truediv__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor":
+        """
+        Divide a tensor by a scalar or one tensor by another.
+        """
+        pass
+
+    def __mul__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor":
+        """
+        Multiply a tensor by a scalar or one tensor by another.
+        """
+        pass
+
+    def __rmul__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor":
+        """
+        Multiply a tensor by a scalar or one tensor by another.
+        """
+        pass
+
+    def __richcmp__(self, rhs: Union["Tensor", "Scalar"], op) -> "Tensor":
+        """
+        Compare a tensor with a scalar or one tensor with another.
+        """
+        pass
+
+    def __getitem__(self, index: Union["Index", "Tensor", Sequence["Index"]]) -> "Tensor":
+        """
+        Return a slice of a tensor.
+        """
+        pass
diff --git a/candle-pyo3/py_src/candle/__init__.pyi b/candle-pyo3/py_src/candle/__init__.pyi
index 414f0bc4..4096907b 100644
--- a/candle-pyo3/py_src/candle/__init__.pyi
+++ b/candle-pyo3/py_src/candle/__init__.pyi
@@ -1,7 +1,7 @@
 # Generated content DO NOT EDIT
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence
 from os import PathLike
-from candle.typing import _ArrayLike, Device
+from candle.typing import _ArrayLike, Device, Scalar, Index
 
 class bf16(DType):
     pass
@@ -119,6 +119,46 @@ class Tensor:
 
     def __init__(self, data: _ArrayLike):
         pass
+    def __add__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
+        """
+        Add a scalar to a tensor or two tensors together.
+        """
+        pass
+    def __getitem__(self, index: Union[Index, Tensor, Sequence[Index]]) -> "Tensor":
+        """
+        Return a slice of a tensor.
+        """
+        pass
+    def __mul__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
+        """
+        Multiply a tensor by a scalar or one tensor by another.
+        """
+        pass
+    def __radd__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
+        """
+        Add a scalar to a tensor or two tensors together.
+        """
+        pass
+    def __richcmp__(self, rhs: Union[Tensor, Scalar], op) -> "Tensor":
+        """
+        Compare a tensor with a scalar or one tensor with another.
+        """
+        pass
+    def __rmul__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
+        """
+        Multiply a tensor by a scalar or one tensor by another.
+        """
+        pass
+    def __sub__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
+        """
+        Subtract a scalar from a tensor or one tensor from another.
+        """
+        pass
+    def __truediv__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
+        """
+        Divide a tensor by a scalar or one tensor by another.
+        """
+        pass
     def argmax_keepdim(self, dim: int) -> Tensor:
         """
         Returns the indices of the maximum value(s) across the selected dimension.
diff --git a/candle-pyo3/py_src/candle/functional/__init__.pyi b/candle-pyo3/py_src/candle/functional/__init__.pyi
index 6f206e40..5bf5c4c3 100644
--- a/candle-pyo3/py_src/candle/functional/__init__.pyi
+++ b/candle-pyo3/py_src/candle/functional/__init__.pyi
@@ -1,7 +1,7 @@
 # Generated content DO NOT EDIT
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence
 from os import PathLike
-from candle.typing import _ArrayLike, Device
+from candle.typing import _ArrayLike, Device, Scalar, Index
 from candle import Tensor, DType, QTensor
 
 @staticmethod
diff --git a/candle-pyo3/py_src/candle/typing/__init__.py b/candle-pyo3/py_src/candle/typing/__init__.py
index ccdb6238..66bc3d8a 100644
--- a/candle-pyo3/py_src/candle/typing/__init__.py
+++ b/candle-pyo3/py_src/candle/typing/__init__.py
@@ -14,3 +14,7 @@ CPU: str = "cpu"
 CUDA: str = "cuda"
 
 Device = TypeVar("Device", CPU, CUDA)
+
+Scalar = Union[int, float]
+
+Index = Union[int, slice, None, "Ellipsis"]
diff --git a/candle-pyo3/py_src/candle/utils/__init__.pyi b/candle-pyo3/py_src/candle/utils/__init__.pyi
index 61964ffc..d3b93766 100644
--- a/candle-pyo3/py_src/candle/utils/__init__.pyi
+++ b/candle-pyo3/py_src/candle/utils/__init__.pyi
@@ -1,7 +1,7 @@
 # Generated content DO NOT EDIT
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence
 from os import PathLike
-from candle.typing import _ArrayLike, Device
+from candle.typing import _ArrayLike, Device, Scalar, Index
 from candle import Tensor, DType, QTensor
 
 @staticmethod
diff --git a/candle-pyo3/stub.py b/candle-pyo3/stub.py
index 3100a10c..8e4318bc 100644
--- a/candle-pyo3/stub.py
+++ b/candle-pyo3/stub.py
@@ -5,6 +5,7 @@ import os
 from typing import Optional
 import black
 from pathlib import Path
+import re
 
 
 INDENT = " " * 4
@@ -12,9 +13,11 @@ GENERATED_COMMENT = "# Generated content DO NOT EDIT\n"
 TYPING = """from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence
 from os import PathLike
 """
-CANDLE_SPECIFIC_TYPING = "from candle.typing import _ArrayLike, Device\n"
+CANDLE_SPECIFIC_TYPING = "from candle.typing import _ArrayLike, Device, Scalar, Index\n"
 CANDLE_TENSOR_IMPORTS = "from candle import Tensor,DType,QTensor\n"
 RETURN_TYPE_MARKER = "&RETURNS&: "
+ADDITIONAL_TYPEHINTS = {}
+FORWARD_REF_PATTERN = re.compile(r"ForwardRef\('([^']+)'\)")
 
 
 def do_indent(text: Optional[str], indent: str):
@@ -115,6 +118,27 @@ def pyi_file(obj, indent=""):
             body += f"{indent+INDENT}pass\n"
             body += "\n"
 
+        if obj.__name__ in ADDITIONAL_TYPEHINTS:
+            additional_members = inspect.getmembers(ADDITIONAL_TYPEHINTS[obj.__name__])
+            additional_functions = []
+            for name, member in additional_members:
+                if inspect.isfunction(member):
+                    additional_functions.append((name, member))
+
+            def process_additional_function(fn):
+                signature = inspect.signature(fn)
+                cleaned_signature = re.sub(FORWARD_REF_PATTERN, r"\1", str(signature))
+                string = f"{indent}def {fn.__name__}{cleaned_signature}:\n"
+                string += (
+                    f'{indent+INDENT}"""{indent+INDENT}{do_indent(fn.__doc__, indent+INDENT)}{indent+INDENT}"""\n'
+                )
+                string += f"{indent+INDENT}pass\n"
+                string += "\n"
+                return string
+
+            for name, fn in additional_functions:
+                body += process_additional_function(fn)
+
         for name, fn in fns:
             body += pyi_file(fn, indent=indent)
 
@@ -215,6 +239,19 @@ def write(module, directory, origin, check=False):
         write(submodule, os.path.join(directory, name), f"{name}", check=check)
 
 
+def extract_additional_types(module):
+    additional_types = {}
+    for name, member in inspect.getmembers(module):
+        if inspect.isclass(member):
+            if hasattr(member, "__name__"):
+                name = member.__name__
+            else:
+                name = str(member)
+            if name not in additional_types:
+                additional_types[name] = member
+    return additional_types
+
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--check", action="store_true")
@@ -228,5 +265,8 @@ if __name__ == "__main__":
         directory = f"candle-pyo3/{directory}"
 
     import candle
+    import _additional_typing
+
+    ADDITIONAL_TYPEHINTS = extract_additional_types(_additional_typing)
 
     write(candle.candle, directory, "candle", check=args.check)
-- 
cgit v1.2.3