summaryrefslogtreecommitdiff
path: root/candle-pyo3/stub.py
diff options
context:
space:
mode:
Diffstat (limited to 'candle-pyo3/stub.py')
-rw-r--r--candle-pyo3/stub.py23
1 files changed, 19 insertions, 4 deletions
diff --git a/candle-pyo3/stub.py b/candle-pyo3/stub.py
index b5b9256f..149715c2 100644
--- a/candle-pyo3/stub.py
+++ b/candle-pyo3/stub.py
@@ -13,8 +13,8 @@ TYPING = """from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from os import PathLike
"""
CANDLE_SPECIFIC_TYPING = "from candle.typing import _ArrayLike, Device\n"
-CANDLE_TENSOR_IMPORTS = "from candle import Tensor,DType\n"
-
+CANDLE_TENSOR_IMPORTS = "from candle import Tensor,DType,QTensor\n"
+RETURN_TYPE_MARKER = "&RETURNS&: "
def do_indent(text: Optional[str], indent: str):
@@ -28,11 +28,26 @@ def function(obj, indent:str, text_signature:str=None):
text_signature = obj.__text_signature__
text_signature = text_signature.replace("$self", "self").lstrip().rstrip()
+ doc_string = obj.__doc__
+ if doc_string is None:
+ doc_string = ""
+
+ # Check if we have a return type annotation in the docstring
+ return_type = None
+ doc_lines = doc_string.split("\n")
+ if doc_lines[-1].lstrip().startswith(RETURN_TYPE_MARKER):
+ # Extract the return type and remove it from the docstring
+ return_type = doc_lines[-1].lstrip()[len(RETURN_TYPE_MARKER):].strip()
+ doc_string = "\n".join(doc_lines[:-1])
+
string = ""
- string += f"{indent}def {obj.__name__}{text_signature}:\n"
+ if return_type:
+ string += f"{indent}def {obj.__name__}{text_signature} -> {return_type}:\n"
+ else:
+ string += f"{indent}def {obj.__name__}{text_signature}:\n"
indent += INDENT
string += f'{indent}"""\n'
- string += f"{indent}{do_indent(obj.__doc__, indent)}\n"
+ string += f"{indent}{do_indent(doc_string, indent)}\n"
string += f'{indent}"""\n'
string += f"{indent}pass\n"
string += "\n"