diff options
Diffstat (limited to 'candle-pyo3/stub.py')
-rw-r--r-- | candle-pyo3/stub.py | 217 |
1 files changed, 217 insertions, 0 deletions
# See: https://raw.githubusercontent.com/huggingface/tokenizers/main/bindings/python/stub.py
"""Generate ``__init__.pyi`` stubs and ``__init__.py`` re-export shims.

The script introspects a compiled extension module (``candle.candle`` when run
as a script) and, for the module and each of its submodules, writes a type
stub and a small Python file that re-exports the module members.  With
``--check`` nothing is written; the files on disk are compared against the
freshly generated content instead.
"""
import argparse
import inspect
import os
from pathlib import Path
from typing import Optional


INDENT = " " * 4
GENERATED_COMMENT = "# Generated content DO NOT EDIT\n"
TYPING = """from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence
from os import PathLike
"""
CANDLE_SPECIFIC_TYPING = "from candle.typing import _ArrayLike, Device\n"
CANDLE_TENSOR_IMPORTS = "from candle import Tensor,DType\n"


def do_indent(text: Optional[str], indent: str) -> str:
    """Prefix every line of ``text`` except the first with ``indent``.

    Returns an empty string when ``text`` is ``None`` (e.g. a missing docstring).
    """
    if text is None:
        return ""
    return text.replace("\n", f"\n{indent}")


def function(obj, indent: str, text_signature: Optional[str] = None) -> str:
    """Render a ``def`` stub (signature, docstring, ``pass``) for ``obj``.

    Args:
        obj: Object providing ``__name__`` and ``__doc__``; its
            ``__text_signature__`` is used when ``text_signature`` is not given.
        indent: Indentation placed in front of the ``def`` line.
        text_signature: Signature text such as ``"(self)"`` overriding the one
            stored on ``obj``.

    Returns:
        The stub source, terminated by two blank lines.
    """
    if text_signature is None:
        text_signature = obj.__text_signature__

    # Text signatures spell the receiver as "$self".
    text_signature = text_signature.replace("$self", "self").strip()
    body_indent = indent + INDENT
    return (
        f"{indent}def {obj.__name__}{text_signature}:\n"
        f'{body_indent}"""\n'
        f"{body_indent}{do_indent(obj.__doc__, body_indent)}\n"
        f'{body_indent}"""\n'
        f"{body_indent}pass\n"
        "\n"
        "\n"
    )


def member_sort(member) -> int:
    """Sort key that places non-class members first, then classes by MRO length.

    Ordering classes by MRO length emits base classes before their subclasses.
    """
    if inspect.isclass(member):
        return 10 + len(inspect.getmro(member))
    return 1


def fn_predicate(obj) -> bool:
    """Tell whether ``obj`` is a public class member that should get a stub.

    Method descriptors and builtins qualify when they carry a text signature
    and a name without a leading underscore; get/set descriptors qualify when
    their name has no leading underscore.
    """
    if inspect.ismethoddescriptor(obj) or inspect.isbuiltin(obj):
        has_signature = bool(getattr(obj, "__text_signature__", None))
        return has_signature and not obj.__name__.startswith("_")
    if inspect.isgetsetdescriptor(obj):
        return not obj.__name__.startswith("_")
    return False


def get_module_members(module):
    """Return the public, non-module members of ``module`` ordered by ``member_sort``."""
    members = [
        member
        for name, member in inspect.getmembers(module)
        if not name.startswith("_") and not inspect.ismodule(member)
    ]
    members.sort(key=member_sort)
    return members


def _module_stub(module, indent: str) -> str:
    """Render the stub header plus one stub per public member of ``module``."""
    parts = [GENERATED_COMMENT, TYPING, CANDLE_SPECIFIC_TYPING]
    # Every module except "candle.candle" gets the Tensor/DType import line.
    if module.__name__ != "candle.candle":
        parts.append(CANDLE_TENSOR_IMPORTS)
    parts.extend(pyi_file(member, indent) for member in get_module_members(module))
    return "".join(parts)


def _class_stub(cls, indent: str) -> str:
    """Render a class stub: header, docstring, ``__init__`` and public members."""
    indent += INDENT
    mro = inspect.getmro(cls)
    # An MRO of (cls, object) means there is no explicit base to spell out.
    inherit = f"({mro[1].__name__})" if len(mro) > 2 else ""
    header = f"class {cls.__name__}{inherit}:\n"

    body = ""
    if cls.__doc__:
        body += f'{indent}"""\n{indent}{do_indent(cls.__doc__, indent)}\n{indent}"""\n'

    # The constructor signature, when present, is stored on the class itself.
    init_signature = getattr(cls, "__text_signature__", None)
    if init_signature:
        body += f"{indent}def __init__{init_signature}:\n"
        body += f"{indent + INDENT}pass\n"
        body += "\n"

    for _name, member in inspect.getmembers(cls, fn_predicate):
        body += pyi_file(member, indent=indent)

    if not body:
        body = f"{indent}pass\n"

    return header + body + "\n\n"


def pyi_file(obj, indent=""):
    """Render the ``.pyi`` stub text for ``obj``.

    Supported objects are modules, classes, builtins, method descriptors,
    get/set descriptors and instances of a class named ``DType``.

    Args:
        obj: The object to describe.
        indent: Current indentation; grows by ``INDENT`` inside classes.

    Returns:
        The stub source as a string.

    Raises:
        TypeError: If ``obj`` is none of the supported kinds.
    """
    if inspect.ismodule(obj):
        return _module_stub(obj, indent)

    if inspect.isclass(obj):
        return _class_stub(obj, indent)

    if inspect.isbuiltin(obj):
        # NOTE(review): the decorator is emitted for module-level builtins too
        # (indent == ""); kept so the generated output stays unchanged.
        return f"{indent}@staticmethod\n" + function(obj, indent)

    if inspect.ismethoddescriptor(obj):
        return function(obj, indent)

    if inspect.isgetsetdescriptor(obj):
        # TODO it would be interesting to add the setter maybe ?
        return f"{indent}@property\n" + function(obj, indent, text_signature="(self)")

    if obj.__class__.__name__ == "DType":
        return f"class {str(obj).lower()}(DType):\n" + f"{indent + INDENT}pass\n"

    raise TypeError(f"Object {obj} is not supported")


def py_file(module, origin):
    """Render the ``__init__.py`` source that re-exports ``module`` members from ``origin``."""
    lines = [GENERATED_COMMENT, f"from .. import {origin}\n", "\n"]
    for member in get_module_members(module):
        # Members without a ``__name__`` are referred to by their ``str()`` form.
        name = member.__name__ if hasattr(member, "__name__") else str(member)
        lines.append(f"{name} = {origin}.{name}\n")
    return "".join(lines)


def do_black(content, is_pyi):
    """Format ``content`` with black and return the result.

    Returns ``content`` unchanged when black reports that nothing changed.
    """
    # Imported here, next to its only user, so the generation helpers above can
    # be imported without the formatter installed.
    import black

    mode = black.Mode(
        target_versions={black.TargetVersion.PY35},
        line_length=119,
        is_pyi=is_pyi,
        string_normalization=True,
        experimental_string_processing=False,
    )
    try:
        return black.format_file_contents(content, fast=True, mode=mode)
    except black.NothingChanged:
        return content


def _check_or_write(filename: str, content: str, check: bool) -> None:
    """Write ``content`` to ``filename``, or in check mode verify it matches.

    Raises:
        AssertionError: In check mode, when the file is missing or differs.
    """
    if not check:
        with open(filename, "w", encoding="utf-8") as f:
            f.write(content)
        return

    try:
        with open(filename, "r", encoding="utf-8") as f:
            on_disk = f.read()
    except FileNotFoundError:
        # A missing generated file is reported the same way as a stale one.
        on_disk = None

    # Raised explicitly: an ``assert`` statement would be stripped under ``python -O``.
    if on_disk != content:
        raise AssertionError(f"The content of {filename} seems outdated, please run `python stub.py`")


def _is_auto_generated(filename: str) -> bool:
    """Tell whether ``filename`` is absent or starts with the generated-content marker."""
    if not os.path.exists(filename):
        return True
    with open(filename, "r", encoding="utf-8") as f:
        return f.readline() == GENERATED_COMMENT


def write(module, directory, origin, check=False):
    """Generate (or check) the stub and re-export files for ``module`` and its submodules.

    Args:
        module: Module to introspect.
        directory: Directory receiving ``__init__.pyi`` and ``__init__.py``.
        origin: Name the generated ``__init__.py`` imports the members from.
        check: When true, write nothing and compare against the files on disk.

    Raises:
        AssertionError: In check mode, when a generated file is missing or outdated.
    """
    submodules = [(name, member) for name, member in inspect.getmembers(module) if inspect.ismodule(member)]

    pyi_content = do_black(pyi_file(module), is_pyi=True)
    if not check:
        # A check must not modify the tree, so directories are only created when writing.
        os.makedirs(directory, exist_ok=True)
    _check_or_write(os.path.join(directory, "__init__.pyi"), pyi_content, check)

    py_filename = os.path.join(directory, "__init__.py")
    py_content = do_black(py_file(module, origin), is_pyi=False)
    # An existing __init__.py without the marker on its first line is left alone.
    if _is_auto_generated(py_filename):
        _check_or_write(py_filename, py_content, check)

    for name, submodule in submodules:
        write(submodule, os.path.join(directory, name), f"{name}", check=check)


def main() -> None:
    """Parse the command line and generate or check the stubs for ``candle.candle``."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--check", action="store_true")

    args = parser.parse_args()

    # Enable execution from the candle and candle-pyo3 directories
    directory = "py_src/candle/"
    if Path.cwd().name != "candle-pyo3":
        directory = f"candle-pyo3/{directory}"

    import candle

    write(candle.candle, directory, "candle", check=args.check)


if __name__ == "__main__":
    main()