summaryrefslogtreecommitdiff
path: root/candle-pyo3/py_src/candle
diff options
context:
space:
mode:
authorLukas Kreussel <65088241+LLukas22@users.noreply.github.com>2023-10-17 12:07:26 +0200
committerGitHub <noreply@github.com>2023-10-17 11:07:26 +0100
commitf9e93f5b6909b4f680c244a0d049add181675958 (patch)
treee509752e90521d6500eb22e35e56b6322a9b6706 /candle-pyo3/py_src/candle
parentb355ab4e2e52b077e71aac46c286fbce033f36d6 (diff)
downloadcandle-f9e93f5b6909b4f680c244a0d049add181675958.tar.gz
candle-f9e93f5b6909b4f680c244a0d049add181675958.tar.bz2
candle-f9e93f5b6909b4f680c244a0d049add181675958.zip
Extend `stub.py` to accept external typehinting (#1102)
Diffstat (limited to 'candle-pyo3/py_src/candle')
-rw-r--r--candle-pyo3/py_src/candle/__init__.pyi42
-rw-r--r--candle-pyo3/py_src/candle/functional/__init__.pyi2
-rw-r--r--candle-pyo3/py_src/candle/typing/__init__.py4
-rw-r--r--candle-pyo3/py_src/candle/utils/__init__.pyi2
4 files changed, 47 insertions, 3 deletions
diff --git a/candle-pyo3/py_src/candle/__init__.pyi b/candle-pyo3/py_src/candle/__init__.pyi
index 414f0bc4..4096907b 100644
--- a/candle-pyo3/py_src/candle/__init__.pyi
+++ b/candle-pyo3/py_src/candle/__init__.pyi
@@ -1,7 +1,7 @@
# Generated content DO NOT EDIT
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence
from os import PathLike
-from candle.typing import _ArrayLike, Device
+from candle.typing import _ArrayLike, Device, Scalar, Index
class bf16(DType):
pass
@@ -119,6 +119,46 @@ class Tensor:
def __init__(self, data: _ArrayLike):
pass
+ def __add__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
+ """
+ Add a scalar to a tensor or two tensors together.
+ """
+ pass
+ def __getitem__(self, index: Union[Index, Tensor, Sequence[Index]]) -> "Tensor":
+ """
+ Return a slice of a tensor.
+ """
+ pass
+ def __mul__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
+ """
+ Multiply a tensor by a scalar or one tensor by another.
+ """
+ pass
+ def __radd__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
+ """
+ Add a scalar to a tensor or two tensors together.
+ """
+ pass
+ def __richcmp__(self, rhs: Union[Tensor, Scalar], op) -> "Tensor":
+ """
+ Compare a tensor with a scalar or one tensor with another.
+ """
+ pass
+ def __rmul__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
+ """
+ Multiply a tensor by a scalar or one tensor by another.
+ """
+ pass
+ def __sub__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
+ """
+ Subtract a scalar from a tensor or one tensor from another.
+ """
+ pass
+ def __truediv__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
+ """
+ Divide a tensor by a scalar or one tensor by another.
+ """
+ pass
def argmax_keepdim(self, dim: int) -> Tensor:
"""
Returns the indices of the maximum value(s) across the selected dimension.
diff --git a/candle-pyo3/py_src/candle/functional/__init__.pyi b/candle-pyo3/py_src/candle/functional/__init__.pyi
index 6f206e40..5bf5c4c3 100644
--- a/candle-pyo3/py_src/candle/functional/__init__.pyi
+++ b/candle-pyo3/py_src/candle/functional/__init__.pyi
@@ -1,7 +1,7 @@
# Generated content DO NOT EDIT
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence
from os import PathLike
-from candle.typing import _ArrayLike, Device
+from candle.typing import _ArrayLike, Device, Scalar, Index
from candle import Tensor, DType, QTensor
@staticmethod
diff --git a/candle-pyo3/py_src/candle/typing/__init__.py b/candle-pyo3/py_src/candle/typing/__init__.py
index ccdb6238..66bc3d8a 100644
--- a/candle-pyo3/py_src/candle/typing/__init__.py
+++ b/candle-pyo3/py_src/candle/typing/__init__.py
@@ -14,3 +14,7 @@ CPU: str = "cpu"
CUDA: str = "cuda"
Device = TypeVar("Device", CPU, CUDA)
+
+Scalar = Union[int, float]
+
+Index = Union[int, slice, None, "Ellipsis"]
diff --git a/candle-pyo3/py_src/candle/utils/__init__.pyi b/candle-pyo3/py_src/candle/utils/__init__.pyi
index 61964ffc..d3b93766 100644
--- a/candle-pyo3/py_src/candle/utils/__init__.pyi
+++ b/candle-pyo3/py_src/candle/utils/__init__.pyi
@@ -1,7 +1,7 @@
# Generated content DO NOT EDIT
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence
from os import PathLike
-from candle.typing import _ArrayLike, Device
+from candle.typing import _ArrayLike, Device, Scalar, Index
from candle import Tensor, DType, QTensor
@staticmethod