diff options
Diffstat (limited to 'candle-pyo3/stub.py')
-rw-r--r-- | candle-pyo3/stub.py | 23 |
1 files changed, 19 insertions, 4 deletions
diff --git a/candle-pyo3/stub.py b/candle-pyo3/stub.py index b5b9256f..149715c2 100644 --- a/candle-pyo3/stub.py +++ b/candle-pyo3/stub.py @@ -13,8 +13,8 @@ TYPING = """from typing import Any, Callable, Dict, List, Optional, Tuple, Union from os import PathLike """ CANDLE_SPECIFIC_TYPING = "from candle.typing import _ArrayLike, Device\n" -CANDLE_TENSOR_IMPORTS = "from candle import Tensor,DType\n" - +CANDLE_TENSOR_IMPORTS = "from candle import Tensor,DType,QTensor\n" +RETURN_TYPE_MARKER = "&RETURNS&: " def do_indent(text: Optional[str], indent: str): @@ -28,11 +28,26 @@ def function(obj, indent:str, text_signature:str=None): text_signature = obj.__text_signature__ text_signature = text_signature.replace("$self", "self").lstrip().rstrip() + doc_string = obj.__doc__ + if doc_string is None: + doc_string = "" + + # Check if we have a return type annotation in the docstring + return_type = None + doc_lines = doc_string.split("\n") + if doc_lines[-1].lstrip().startswith(RETURN_TYPE_MARKER): + # Extract the return type and remove it from the docstring + return_type = doc_lines[-1].lstrip()[len(RETURN_TYPE_MARKER):].strip() + doc_string = "\n".join(doc_lines[:-1]) + string = "" - string += f"{indent}def {obj.__name__}{text_signature}:\n" + if return_type: + string += f"{indent}def {obj.__name__}{text_signature} -> {return_type}:\n" + else: + string += f"{indent}def {obj.__name__}{text_signature}:\n" indent += INDENT string += f'{indent}"""\n' - string += f"{indent}{do_indent(obj.__doc__, indent)}\n" + string += f"{indent}{do_indent(doc_string, indent)}\n" string += f'{indent}"""\n' string += f"{indent}pass\n" string += "\n" |