summaryrefslogtreecommitdiff
path: root/candle-pyo3/py_src/candle/__init__.pyi
diff options
context:
space:
mode:
Diffstat (limited to 'candle-pyo3/py_src/candle/__init__.pyi')
-rw-r--r--candle-pyo3/py_src/candle/__init__.pyi41
1 files changed, 41 insertions, 0 deletions
diff --git a/candle-pyo3/py_src/candle/__init__.pyi b/candle-pyo3/py_src/candle/__init__.pyi
index 35b17680..37b8fe8c 100644
--- a/candle-pyo3/py_src/candle/__init__.pyi
+++ b/candle-pyo3/py_src/candle/__init__.pyi
@@ -124,16 +124,46 @@ class Tensor:
Add a scalar to a tensor or two tensors together.
"""
pass
+ def __eq__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
+ """
+ Compare a tensor with a scalar or one tensor with another.
+ """
+ pass
+ def __ge__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
+ """
+ Compare a tensor with a scalar or one tensor with another.
+ """
+ pass
def __getitem__(self, index: Union[Index, Tensor, Sequence[Index]]) -> "Tensor":
"""
Return a slice of a tensor.
"""
pass
+ def __gt__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
+ """
+ Compare a tensor with a scalar or one tensor with another.
+ """
+ pass
+ def __le__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
+ """
+ Compare a tensor with a scalar or one tensor with another.
+ """
+ pass
+ def __lt__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
+ """
+ Compare a tensor with a scalar or one tensor with another.
+ """
+ pass
def __mul__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
"""
Multiply a tensor by a scalar or one tensor by another.
"""
pass
+ def __ne__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
+ """
+ Compare a tensor with a scalar or one tensor with another.
+ """
+ pass
def __radd__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
"""
Add a scalar to a tensor or two tensors together.
@@ -159,6 +189,11 @@ class Tensor:
Divide a tensor by a scalar or one tensor by another.
"""
pass
+ def abs(self) -> Tensor:
+ """
+ Performs the `abs` operation on the tensor.
+ """
+ pass
def argmax_keepdim(self, dim: int) -> Tensor:
"""
Returns the indices of the maximum value(s) across the selected dimension.
@@ -308,6 +343,12 @@ class Tensor:
ranges from `start` to `start + len`.
"""
pass
+ @property
+ def nelement(self) -> int:
+ """
+ Gets the tensor's element count.
+ """
+ pass
def powf(self, p: float) -> Tensor:
"""
Performs the `pow` operation on the tensor with the given exponent.