summaryrefslogtreecommitdiff
path: root/candle-pyo3/stub.py
diff options
context:
space:
mode:
Diffstat (limited to 'candle-pyo3/stub.py')
-rw-r--r--candle-pyo3/stub.py217
1 files changed, 217 insertions, 0 deletions
diff --git a/candle-pyo3/stub.py b/candle-pyo3/stub.py
new file mode 100644
index 00000000..b5b9256f
--- /dev/null
+++ b/candle-pyo3/stub.py
@@ -0,0 +1,217 @@
+#See: https://raw.githubusercontent.com/huggingface/tokenizers/main/bindings/python/stub.py
+import argparse
+import inspect
+import os
+from typing import Optional
+import black
+from pathlib import Path
+
+
+INDENT = " " * 4
+GENERATED_COMMENT = "# Generated content DO NOT EDIT\n"
+TYPING = """from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence
+from os import PathLike
+"""
+CANDLE_SPECIFIC_TYPING = "from candle.typing import _ArrayLike, Device\n"
+CANDLE_TENSOR_IMPORTS = "from candle import Tensor,DType\n"
+
+
+
+def do_indent(text: Optional[str], indent: str):
+ if text is None:
+ return ""
+ return text.replace("\n", f"\n{indent}")
+
+
+def function(obj, indent:str, text_signature:str=None):
+ if text_signature is None:
+ text_signature = obj.__text_signature__
+
+ text_signature = text_signature.replace("$self", "self").lstrip().rstrip()
+ string = ""
+ string += f"{indent}def {obj.__name__}{text_signature}:\n"
+ indent += INDENT
+ string += f'{indent}"""\n'
+ string += f"{indent}{do_indent(obj.__doc__, indent)}\n"
+ string += f'{indent}"""\n'
+ string += f"{indent}pass\n"
+ string += "\n"
+ string += "\n"
+ return string
+
+
+def member_sort(member):
+ if inspect.isclass(member):
+ value = 10 + len(inspect.getmro(member))
+ else:
+ value = 1
+ return value
+
+
+def fn_predicate(obj):
+ value = inspect.ismethoddescriptor(obj) or inspect.isbuiltin(obj)
+ if value:
+ return obj.__text_signature__ and not obj.__name__.startswith("_")
+ if inspect.isgetsetdescriptor(obj):
+ return not obj.__name__.startswith("_")
+ return False
+
+
+def get_module_members(module):
+ members = [
+ member
+ for name, member in inspect.getmembers(module)
+ if not name.startswith("_") and not inspect.ismodule(member)
+ ]
+ members.sort(key=member_sort)
+ return members
+
+
+def pyi_file(obj, indent=""):
+ string = ""
+ if inspect.ismodule(obj):
+ string += GENERATED_COMMENT
+ string += TYPING
+ string += CANDLE_SPECIFIC_TYPING
+ if obj.__name__ != "candle.candle":
+ string += CANDLE_TENSOR_IMPORTS
+ members = get_module_members(obj)
+ for member in members:
+ string += pyi_file(member, indent)
+
+ elif inspect.isclass(obj):
+ indent += INDENT
+ mro = inspect.getmro(obj)
+ if len(mro) > 2:
+ inherit = f"({mro[1].__name__})"
+ else:
+ inherit = ""
+ string += f"class {obj.__name__}{inherit}:\n"
+
+ body = ""
+ if obj.__doc__:
+ body += f'{indent}"""\n{indent}{do_indent(obj.__doc__, indent)}\n{indent}"""\n'
+
+ fns = inspect.getmembers(obj, fn_predicate)
+
+ # Init
+ if obj.__text_signature__:
+ body += f"{indent}def __init__{obj.__text_signature__}:\n"
+ body += f"{indent+INDENT}pass\n"
+ body += "\n"
+
+ for (name, fn) in fns:
+ body += pyi_file(fn, indent=indent)
+
+ if not body:
+ body += f"{indent}pass\n"
+
+ string += body
+ string += "\n\n"
+
+ elif inspect.isbuiltin(obj):
+ string += f"{indent}@staticmethod\n"
+ string += function(obj, indent)
+
+ elif inspect.ismethoddescriptor(obj):
+ string += function(obj, indent)
+
+ elif inspect.isgetsetdescriptor(obj):
+ # TODO it would be interesing to add the setter maybe ?
+ string += f"{indent}@property\n"
+ string += function(obj, indent, text_signature="(self)")
+
+ elif obj.__class__.__name__ == "DType":
+ string += f"class {str(obj).lower()}(DType):\n"
+ string += f"{indent+INDENT}pass\n"
+ else:
+ raise Exception(f"Object {obj} is not supported")
+ return string
+
+
+def py_file(module, origin):
+ members = get_module_members(module)
+
+ string = GENERATED_COMMENT
+ string += f"from .. import {origin}\n"
+ string += "\n"
+ for member in members:
+ if hasattr(member, "__name__"):
+ name = member.__name__
+ else:
+ name = str(member)
+ string += f"{name} = {origin}.{name}\n"
+ return string
+
+
+def do_black(content, is_pyi):
+ mode = black.Mode(
+ target_versions={black.TargetVersion.PY35},
+ line_length=119,
+ is_pyi=is_pyi,
+ string_normalization=True,
+ experimental_string_processing=False,
+ )
+ try:
+ return black.format_file_contents(content, fast=True, mode=mode)
+ except black.NothingChanged:
+ return content
+
+
+def write(module, directory, origin, check=False):
+ submodules = [(name, member) for name, member in inspect.getmembers(module) if inspect.ismodule(member)]
+
+ filename = os.path.join(directory, "__init__.pyi")
+ pyi_content = pyi_file(module)
+ pyi_content = do_black(pyi_content, is_pyi=True)
+ os.makedirs(directory, exist_ok=True)
+ if check:
+ with open(filename, "r") as f:
+ data = f.read()
+ assert data == pyi_content, f"The content of {filename} seems outdated, please run `python stub.py`"
+ else:
+ with open(filename, "w") as f:
+ f.write(pyi_content)
+
+ filename = os.path.join(directory, "__init__.py")
+ py_content = py_file(module, origin)
+ py_content = do_black(py_content, is_pyi=False)
+ os.makedirs(directory, exist_ok=True)
+
+ is_auto = False
+ if not os.path.exists(filename):
+ is_auto = True
+ else:
+ with open(filename, "r") as f:
+ line = f.readline()
+ if line == GENERATED_COMMENT:
+ is_auto = True
+
+ if is_auto:
+ if check:
+ with open(filename, "r") as f:
+ data = f.read()
+ assert data == py_content, f"The content of {filename} seems outdated, please run `python stub.py`"
+ else:
+ with open(filename, "w") as f:
+ f.write(py_content)
+
+ for name, submodule in submodules:
+ write(submodule, os.path.join(directory, name), f"{name}", check=check)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--check", action="store_true")
+
+ args = parser.parse_args()
+
+ #Enable execution from the candle and candle-pyo3 directories
+ cwd = Path.cwd()
+ directory = "py_src/candle/"
+ if cwd.name != "candle-pyo3":
+ directory = f"candle-pyo3/{directory}"
+
+ import candle
+
+ write(candle.candle, directory, "candle", check=args.check)