summaryrefslogtreecommitdiff
path: root/candle-pyo3/_additional_typing
diff options
context:
space:
mode:
authorLukas Kreussel <65088241+LLukas22@users.noreply.github.com>2023-10-17 12:07:26 +0200
committerGitHub <noreply@github.com>2023-10-17 11:07:26 +0100
commitf9e93f5b6909b4f680c244a0d049add181675958 (patch)
treee509752e90521d6500eb22e35e56b6322a9b6706 /candle-pyo3/_additional_typing
parentb355ab4e2e52b077e71aac46c286fbce033f36d6 (diff)
downloadcandle-f9e93f5b6909b4f680c244a0d049add181675958.tar.gz
candle-f9e93f5b6909b4f680c244a0d049add181675958.tar.bz2
candle-f9e93f5b6909b4f680c244a0d049add181675958.zip
Extend `stub.py` to accept external typehinting (#1102)
Diffstat (limited to 'candle-pyo3/_additional_typing')
-rw-r--r--candle-pyo3/_additional_typing/README.md3
-rw-r--r--candle-pyo3/_additional_typing/__init__.py55
2 files changed, 58 insertions, 0 deletions
diff --git a/candle-pyo3/_additional_typing/README.md b/candle-pyo3/_additional_typing/README.md
new file mode 100644
index 00000000..ab5074e0
--- /dev/null
+++ b/candle-pyo3/_additional_typing/README.md
@@ -0,0 +1,3 @@
+This python module contains external typehinting for certain `candle` classes. This is only necessary for `magic` methodes e.g. `__add__` as their text signature cant be set via pyo3.
+
+The classes in this module will be parsed by the `stub.py` script and interleafed with the signatures of the actual pyo3 `candle.candle` module. \ No newline at end of file
diff --git a/candle-pyo3/_additional_typing/__init__.py b/candle-pyo3/_additional_typing/__init__.py
new file mode 100644
index 00000000..0d0eec90
--- /dev/null
+++ b/candle-pyo3/_additional_typing/__init__.py
@@ -0,0 +1,55 @@
+from typing import Union, Sequence
+
+
+class Tensor:
+ """
+ This contains the type hints for the magic methodes of the `candle.Tensor` class.
+ """
+
+ def __add__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor":
+ """
+ Add a scalar to a tensor or two tensors together.
+ """
+ pass
+
+ def __radd__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor":
+ """
+ Add a scalar to a tensor or two tensors together.
+ """
+ pass
+
+ def __sub__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor":
+ """
+ Subtract a scalar from a tensor or one tensor from another.
+ """
+ pass
+
+ def __truediv__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor":
+ """
+ Divide a tensor by a scalar or one tensor by another.
+ """
+ pass
+
+ def __mul__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor":
+ """
+ Multiply a tensor by a scalar or one tensor by another.
+ """
+ pass
+
+ def __rmul__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor":
+ """
+ Multiply a tensor by a scalar or one tensor by another.
+ """
+ pass
+
+ def __richcmp__(self, rhs: Union["Tensor", "Scalar"], op) -> "Tensor":
+ """
+ Compare a tensor with a scalar or one tensor with another.
+ """
+ pass
+
+ def __getitem__(self, index: Union["Index", "Tensor", Sequence["Index"]]) -> "Tensor":
+ """
+ Return a slice of a tensor.
+ """
+ pass